# Validate the arguments of an event-study (ES) call before any estimation work.
#
# Each check either passes silently, raises an error (via stop() or a
# checkmate assert*() call), or -- for non-fatal concerns such as
# cluster_vars = NULL or requesting more cores than the machine appears to
# have -- emits a warning().
#
# The checks cover:
#   * types/lengths of scalar and character arguments;
#   * cross-argument consistency (e.g. homogeneous_ATT vs heterogeneous_only,
#     omitted_event_time vs anticipation, reg_weights vs ref_reg_weights);
#   * long_data itself: presence of every referenced column, integer type of
#     cal_time_var / onset_time_var, logical type of the *_subset_var columns,
#     covariates not duplicating the design variables, and agreement between
#     never_treat_action and the presence of NA onset times.
#
# Arguments: presumably the identically-named options of ES() forwarded
# unchanged -- NOTE(review): confirm against ES().
# Returns: invisible(NULL); the function is called for its errors/warnings.
ES_check_inputs <- function(long_data, outcomevar, unit_var, cal_time_var, onset_time_var,
                            cluster_vars, omitted_event_time = -2, anticipation = 0,
                            min_control_gap = 1, max_control_gap = Inf, linearize_pretrends = FALSE,
                            fill_zeros = FALSE, residualize_covariates = FALSE, control_subset_var = NA,
                            control_subset_event_time = 0, treated_subset_var = NA,
                            treated_subset_event_time = 0, control_subset_var2 = NA,
                            control_subset_event_time2 = 0, treated_subset_var2 = NA,
                            treated_subset_event_time2 = 0, discrete_covars = NULL,
                            cont_covars = NULL, never_treat_action = "none", homogeneous_ATT = FALSE,
                            reg_weights = NULL, add_unit_fes = FALSE, bootstrapES = FALSE,
                            bootstrap_iters = 1, bootstrap_num_cores = 1, keep_all_bootstrap_results = FALSE,
                            ipw = FALSE, ipw_model = "linear", ipw_composition_change = FALSE,
                            ipw_keep_data = FALSE, ipw_ps_lower_bound = 0, ipw_ps_upper_bound = 1,
                            event_vs_noevent = FALSE, ref_discrete_covars = NULL, ref_discrete_covar_event_time = 0,
                            ref_cont_covars = NULL, ref_cont_covar_event_time = 0, calculate_collapse_estimates = FALSE,
                            collapse_inputs = NULL, ref_reg_weights = NULL, ref_reg_weights_event_time = 0,
                            ntile_var = NULL, ntile_event_time = -2, ntiles = NA, ntile_var_value = NA,
                            ntile_avg = FALSE, endog_var = NULL, linearDiD = FALSE,
                            linearDiD_treat_var = NULL, cohort_by_cohort = FALSE, cohort_by_cohort_num_cores = 1,
                            heterogeneous_only = FALSE, ref_event_time_mean_vars = NULL) {
    # ---- Argument types and shapes -------------------------------------------
    assertDataTable(long_data)
    assertCharacter(outcomevar, len = 1)
    assertCharacter(unit_var, len = 1)
    assertCharacter(cal_time_var, len = 1)
    assertCharacter(onset_time_var, len = 1)
    if (!is.null(cluster_vars)) {
        assertCharacter(cluster_vars)
    }
    assertIntegerish(omitted_event_time, len = 1, upper = -1)
    assertIntegerish(anticipation, len = 1, lower = 0)
    assertIntegerish(min_control_gap, len = 1, lower = 1)
    # max_control_gap may be infinite (no upper limit on how far ahead a
    # control cohort's onset may lie); otherwise it must be a single
    # integer-valued number no smaller than min_control_gap.
    if (!any(
        testIntegerish(max_control_gap, len = 1, lower = min_control_gap),
        is.infinite(max_control_gap)
    )) {
        assertIntegerish(max_control_gap, len = 1, lower = min_control_gap)
    }
    # The *_subset_var arguments default to NA, meaning "not used".
    if (!is.na(control_subset_var)) {
        assertCharacter(control_subset_var, len = 1)
    }
    if (!is.na(treated_subset_var)) {
        assertCharacter(treated_subset_var, len = 1)
    }
    if (!is.na(control_subset_var2)) {
        assertCharacter(control_subset_var2, len = 1)
    }
    if (!is.na(treated_subset_var2)) {
        assertCharacter(treated_subset_var2, len = 1)
    }
    if (!is.null(reg_weights)) {
        assertCharacter(reg_weights, len = 1)
    }
    if (!is.null(ref_reg_weights)) {
        assertCharacter(ref_reg_weights, len = 1)
    }
    assertIntegerish(control_subset_event_time, len = 1)
    assertIntegerish(treated_subset_event_time, len = 1)
    assertIntegerish(control_subset_event_time2, len = 1)
    assertIntegerish(treated_subset_event_time2, len = 1)
    assertIntegerish(ref_discrete_covar_event_time, len = 1)
    assertIntegerish(ref_cont_covar_event_time, len = 1)
    assertIntegerish(ref_reg_weights_event_time, len = 1)
    assertFlag(linearize_pretrends)
    assertFlag(fill_zeros)
    assertFlag(residualize_covariates)
    if (residualize_covariates &&
        !any(testCharacter(discrete_covars), testCharacter(cont_covars))) {
        stop("Since residualize_covariates=TRUE, either discrete_covars or cont_covars must be provided as a character vector.")
    }
    assertFlag(homogeneous_ATT)
    assertFlag(add_unit_fes)
    assertFlag(bootstrapES)
    assertIntegerish(bootstrap_iters, len = 1, lower = 1)
    assertIntegerish(bootstrap_num_cores, len = 1, lower = 1)
    assertFlag(keep_all_bootstrap_results)
    assertFlag(ipw)
    assertFlag(ipw_composition_change)
    assertNumber(ipw_ps_lower_bound, lower = 0, upper = 1)
    assertNumber(ipw_ps_upper_bound, lower = 0, upper = 1)
    assertFlag(calculate_collapse_estimates)
    if (calculate_collapse_estimates) {
        assertDataTable(collapse_inputs,
            ncols = 2, any.missing = FALSE,
            types = c("character", "list")
        )
    }
    if (!is.null(ntile_var)) {
        # ntile_var is later used as a single column name, so it must be one
        # string; otherwise its column-presence check below would be skipped.
        assertCharacter(ntile_var, len = 1)
        assertIntegerish(ntile_event_time, len = 1)
        assertIntegerish(ntiles, len = 1, lower = 2)
        # NA is accepted here (checkmate's default any.missing = TRUE).
        assertIntegerish(ntile_var_value,
            len = 1, lower = 1,
            upper = ntiles
        )
        assertFlag(ntile_avg)
    }
    if (!is.null(endog_var)) {
        assertCharacter(endog_var, len = 1)
    }
    assertFlag(linearDiD)
    if (!is.null(linearDiD_treat_var) && !linearDiD) {
        stop(sprintf(
            "Supplied linearDiD_treat_var='%s' but either didn't set linearDiD or set linearDiD='%s'. \n Let me suggest linearDiD=TRUE.",
            linearDiD_treat_var, linearDiD
        ))
    }
    assertFlag(cohort_by_cohort)
    assertIntegerish(cohort_by_cohort_num_cores, len = 1, lower = 1)
    assertFlag(heterogeneous_only)

    # ---- Cross-argument consistency ------------------------------------------
    if (homogeneous_ATT && heterogeneous_only) {
        stop(sprintf(
            "homogeneous_ATT='%s' and heterogeneous_only='%s', which logically conflict with one another.",
            homogeneous_ATT, heterogeneous_only
        ))
    }
    # The omitted (reference) event time must precede every anticipation period.
    if (omitted_event_time + anticipation > -1) {
        stop(sprintf(
            "omitted_event_time='%s' and anticipation='%s' implies overlap of pre-treatment and anticipation periods. Let me suggest omitted_event_time<='%s'",
            omitted_event_time, anticipation, ((-1 * anticipation) - 1)
        ))
    }

    # ---- Required columns in long_data ---------------------------------------
    # Columns are read with long_data[[name]] rather than get(name) inside
    # long_data[, ...] so a column that happens to share the name of one of
    # these argument variables cannot shadow the argument.
    if (!(outcomevar %in% names(long_data))) {
        stop(sprintf(
            "Variable outcomevar='%s' is not in the long_data you provided.",
            outcomevar
        ))
    }
    if (!(unit_var %in% names(long_data))) {
        stop(sprintf(
            "Variable unit_var='%s' is not in the long_data you provided.",
            unit_var
        ))
    }
    if (!(cal_time_var %in% names(long_data))) {
        stop(sprintf(
            "Variable cal_time_var='%s' is not in the long_data you provided.",
            cal_time_var
        ))
    }
    if (typeof(long_data[[cal_time_var]]) != "integer") {
        stop(sprintf(
            "Variable cal_time_var='%s' must be of type 'integer' in long_data.",
            cal_time_var
        ))
    }
    if (!(onset_time_var %in% names(long_data))) {
        stop(sprintf(
            "Variable onset_time_var='%s' is not in the long_data you provided.",
            onset_time_var
        ))
    }
    if (typeof(long_data[[onset_time_var]]) != "integer") {
        stop(sprintf(
            "Variable onset_time_var='%s' must be of type 'integer' in long_data.",
            onset_time_var
        ))
    }
    if (!is.null(cluster_vars)) {
        for (vv in cluster_vars) {
            if (!(vv %in% names(long_data))) {
                stop(sprintf(
                    "Variable cluster_vars='%s' is not in the long_data you provided. Let me suggest cluster_vars='%s'.",
                    vv, unit_var
                ))
            }
        }
    }

    # ---- Optional subset indicators: must exist and be logical ----------------
    if (!is.na(control_subset_var)) {
        if (!(control_subset_var %in% names(long_data))) {
            stop(sprintf(
                "Variable control_subset_var='%s' is not in the long_data you provided.",
                control_subset_var
            ))
        }
        if (typeof(long_data[[control_subset_var]]) != "logical") {
            stop(sprintf(
                "Variable control_subset_var='%s' must be of type 'logical' in long_data (i.e., only TRUE or FALSE values).",
                control_subset_var
            ))
        }
    }
    if (!is.na(treated_subset_var)) {
        if (!(treated_subset_var %in% names(long_data))) {
            stop(sprintf(
                "Variable treated_subset_var='%s' is not in the long_data you provided.",
                treated_subset_var
            ))
        }
        if (typeof(long_data[[treated_subset_var]]) != "logical") {
            stop(sprintf(
                "Variable treated_subset_var='%s' must be of type 'logical' in long_data (i.e., only TRUE or FALSE values).",
                treated_subset_var
            ))
        }
    }
    if (!is.na(control_subset_var2)) {
        if (!(control_subset_var2 %in% names(long_data))) {
            stop(sprintf(
                "Variable control_subset_var2='%s' is not in the long_data you provided.",
                control_subset_var2
            ))
        }
        if (typeof(long_data[[control_subset_var2]]) != "logical") {
            stop(sprintf(
                "Variable control_subset_var2='%s' must be of type 'logical' in long_data (i.e., only TRUE or FALSE values).",
                control_subset_var2
            ))
        }
    }
    if (!is.na(treated_subset_var2)) {
        if (!(treated_subset_var2 %in% names(long_data))) {
            stop(sprintf(
                "Variable treated_subset_var2='%s' is not in the long_data you provided.",
                treated_subset_var2
            ))
        }
        if (typeof(long_data[[treated_subset_var2]]) != "logical") {
            stop(sprintf(
                "Variable treated_subset_var2='%s' must be of type 'logical' in long_data (i.e., only TRUE or FALSE values).",
                treated_subset_var2
            ))
        }
    }

    # ---- Optional covariate / weight / auxiliary columns must exist -----------
    if (testCharacter(discrete_covars)) {
        for (vv in discrete_covars) {
            if (!(vv %in% names(long_data))) {
                stop(sprintf(
                    "Variable discrete_covars='%s' is not in the long_data you provided.",
                    vv
                ))
            }
        }
    }
    if (testCharacter(cont_covars)) {
        for (vv in cont_covars) {
            if (!(vv %in% names(long_data))) {
                stop(sprintf(
                    "Variable cont_covars='%s' is not in the long_data you provided.",
                    vv
                ))
            }
        }
    }
    if (testCharacter(ref_discrete_covars)) {
        for (vv in ref_discrete_covars) {
            if (!(vv %in% names(long_data))) {
                stop(sprintf(
                    "Variable ref_discrete_covars='%s' is not in the long_data you provided.",
                    vv
                ))
            }
        }
    }
    if (testCharacter(ref_cont_covars)) {
        for (vv in ref_cont_covars) {
            if (!(vv %in% names(long_data))) {
                stop(sprintf(
                    "Variable ref_cont_covars='%s' is not in the long_data you provided.",
                    vv
                ))
            }
        }
    }
    if (testCharacter(reg_weights) && !(reg_weights %in% names(long_data))) {
        stop(sprintf(
            "Variable reg_weights='%s' is not in the long_data you provided.",
            reg_weights
        ))
    }
    if (testCharacter(ref_reg_weights) && !(ref_reg_weights %in% names(long_data))) {
        stop(sprintf(
            "Variable ref_reg_weights='%s' is not in the long_data you provided.",
            ref_reg_weights
        ))
    }
    if (testCharacter(ntile_var) && !(ntile_var %in% names(long_data))) {
        stop(sprintf(
            "Variable ntile_var='%s' is not in the long_data you provided.",
            ntile_var
        ))
    }
    if (testCharacter(endog_var) && !(endog_var %in% names(long_data))) {
        stop(sprintf(
            "Variable endog_var='%s' is not in the long_data you provided.",
            endog_var
        ))
    }
    if (linearDiD && testCharacter(linearDiD_treat_var) &&
        !(linearDiD_treat_var %in% names(long_data))) {
        stop(sprintf(
            "Variable linearDiD_treat_var='%s' is not in the long_data you provided.",
            linearDiD_treat_var
        ))
    }
    if (testCharacter(ref_event_time_mean_vars)) {
        for (vv in ref_event_time_mean_vars) {
            if (!(vv %in% names(long_data))) {
                stop(sprintf(
                    "Variable ref_event_time_mean_vars='%s' is not in the long_data you provided.",
                    vv
                ))
            }
        }
        # Disabled check, kept for reference:
        # if (outcomevar %in% ref_event_time_mean_vars) {
        #   stop(sprintf("\n Variable outcomevar='%s' was also found among ref_event_time_mean_vars.\n We already grab stacked ES mean of outcomevar, please remove it from ref_event_time_mean_vars.",
        #                vv))
        # }
    }

    # ---- Feature combinations that are not supported -------------------------
    if (!is.null(ntile_var) && omitted_event_time != ntile_event_time) {
        stop(sprintf(
            "ntile_event_time='%s', but currently code can only accept ntile_event_time == omitted_event_time (e.g., %s).",
            ntile_event_time, omitted_event_time
        ))
    }
    if (calculate_collapse_estimates && homogeneous_ATT) {
        stop("Cannot have calculate_collapse_estimates == TRUE & homogeneous_ATT == TRUE. Consider setting homogeneous_ATT = FALSE.")
    }

    # ---- Covariates must not duplicate the design's fixed effects ------------
    # With unit fixed effects the design absorbs (calendar time, unit);
    # otherwise it absorbs (calendar time, onset cohort).
    if (add_unit_fes) {
        design_vars <- c(cal_time_var, unit_var)
    } else {
        design_vars <- c(cal_time_var, onset_time_var)
    }
    if (testCharacter(discrete_covars)) {
        for (vv in discrete_covars) {
            if (vv %in% design_vars) {
                stop(sprintf(
                    "Variable discrete_covars='%s' is among c('%s','%s') which are already controlled in the design.",
                    vv, design_vars[[1]], design_vars[[2]]
                ))
            }
        }
    }
    if (testCharacter(cont_covars)) {
        for (vv in cont_covars) {
            if (vv %in% design_vars) {
                stop(sprintf(
                    "Variable cont_covars='%s' is among c('%s','%s') which are already controlled in the design.",
                    vv, design_vars[[1]], design_vars[[2]]
                ))
            }
        }
    }
    if (testCharacter(ref_discrete_covars)) {
        for (vv in ref_discrete_covars) {
            if (vv %in% design_vars) {
                stop(sprintf(
                    "Variable ref_discrete_covars='%s' is among c('%s','%s') which are already controlled in the design.",
                    vv, design_vars[[1]], design_vars[[2]]
                ))
            }
        }
    }
    if (testCharacter(ref_cont_covars)) {
        for (vv in ref_cont_covars) {
            if (vv %in% design_vars) {
                stop(sprintf(
                    "Variable ref_cont_covars='%s' is among c('%s','%s') which are already controlled in the design.",
                    vv, design_vars[[1]], design_vars[[2]]
                ))
            }
        }
    }
    if (!is.null(reg_weights) && !is.null(ref_reg_weights)) {
        stop("Supplied variables for both reg_weights and ref_reg_weights, but ES() only admits using one or the other type of weight (or neither).")
    }

    # ---- never_treat_action vs. presence of never-treated rows ---------------
    if (!(never_treat_action %in% c("none", "exclude", "keep", "only"))) {
        stop(sprintf(
            "never_treat_action='%s' is not among allowed values (c('none', 'exclude', 'keep', 'only')).",
            never_treat_action
        ))
    }
    # Rows with a missing onset time are treated as never-treated; count them
    # once instead of materialising the subset twice.
    n_never_treated <- sum(is.na(long_data[[onset_time_var]]))
    if (never_treat_action == "none" && n_never_treated > 0) {
        stop(sprintf(
            "never_treat_action='%s' but some units have %s=NA. Please edit supplied long_data or consider another option for never_treat_action.",
            never_treat_action, onset_time_var
        ))
    }
    if (never_treat_action != "none" && n_never_treated == 0) {
        stop(sprintf(
            "never_treat_action='%s' but no units have %s=NA. Let me suggest never_treat_action='none'.",
            never_treat_action, onset_time_var
        ))
    }

    # ---- Non-fatal warnings ---------------------------------------------------
    if (is.null(cluster_vars)) {
        warning(sprintf(
            "Supplied cluster_vars = NULL; given stacking in ES(), standard errors may be too small. Consider cluster_vars='%s' instead.",
            unit_var
        ))
    }
    if (bootstrap_num_cores > 1 || cohort_by_cohort_num_cores > 1) {
        # detectCores() returns NA when the core count cannot be determined;
        # skip the over-subscription warnings in that case rather than letting
        # `if (NA)` abort validation.
        max_usable_cores <- parallel::detectCores() - 1
        if (!is.na(max_usable_cores)) {
            if (bootstrap_num_cores > 1 && cohort_by_cohort_num_cores == 1) {
                if (bootstrap_num_cores > max_usable_cores) {
                    warning(sprintf(
                        "Supplied bootstrap_num_cores='%s'; this exceeds typical system limits and may cause issues.",
                        bootstrap_num_cores
                    ))
                }
            } else if (bootstrap_num_cores == 1 && cohort_by_cohort_num_cores > 1) {
                if (cohort_by_cohort_num_cores > max_usable_cores) {
                    warning(sprintf(
                        "Supplied cohort_by_cohort_num_cores='%s'; this exceeds typical system limits and may cause issues.",
                        cohort_by_cohort_num_cores
                    ))
                }
            } else {
                # Both > 1: the two levels of parallelism multiply.
                total_cores <- as.integer(bootstrap_num_cores * cohort_by_cohort_num_cores)
                if (total_cores > max_usable_cores) {
                    warning(sprintf(
                        "Supplied bootstrap_num_cores='%s' & cohort_by_cohort_num_cores='%s'; the product (%s) exceeds typical system limits and may cause issues.",
                        bootstrap_num_cores, cohort_by_cohort_num_cores,
                        total_cores
                    ))
                }
            }
        }
    }

    # ---- IPW options ----------------------------------------------------------
    if (ipw) {
        if (!(ipw_model %in% c("linear", "logit", "probit"))) {
            stop(sprintf(
                "ipw_model='%s' is not among allowed values (c('linear', 'logit', 'probit')).",
                ipw_model
            ))
        }
    }
    if (ipw_ps_lower_bound > ipw_ps_upper_bound) {
        stop(sprintf(
            "ipw_ps_lower_bound='%s' & ipw_ps_upper_bound='%s', which means ipw_ps_lower_bound > ipw_ps_upper_bound. Consider revising these cutoffs.",
            ipw_ps_lower_bound, ipw_ps_upper_bound
        ))
    }
    return(invisible(NULL))
}


ES_clean_data <- function(long_data, outcomevar, unit_var, cal_time_var, onset_time_var,
                          cluster_vars, discrete_covars = NULL, cont_covars = NULL,
                          ref_discrete_covars = NULL, ref_cont_covars = NULL, anticipation = 0,
                          min_control_gap = 1, max_control_gap = Inf, omitted_event_time = -2,
                          treated_subset_var = NA, treated_subset_event_time = NA,
                          control_subset_var = NA, control_subset_event_time = NA,
                          treated_subset_var2 = NA, treated_subset_event_time2 = NA,
                          control_subset_var2 = NA, control_subset_event_time2 = NA,
                          treated_subset_var3 = NA, treated_subset_event_time3 = NA,
                          control_subset_var3 = NA, control_subset_event_time3 = NA,
                          never_treat_action = "none", never_treat_val = NA, reg_weights = NULL,
                          ref_reg_weights = NULL, event_vs_noevent = FALSE, ntile_var = NULL,
                          ntile_event_time = -2, ntiles = NA, ntile_var_value = NA,
                          ntile_avg = FALSE, endog_var = NULL, linearDiD_treat_var = NULL,
                          ref_event_time_mean_vars = NULL,
                          compare_var = NULL, case = "balanced_in_calendar_time") {
    min_eligible_cohort <- long_data[get(cal_time_var) - get(onset_time_var) ==
        omitted_event_time, min(get(onset_time_var))]
    long_data[, `:=`(relevant_subset, as.integer(get(onset_time_var) >=
        min_eligible_cohort))]
    gc()
    onset_times <- long_data[relevant_subset == 1, sort(unique(get(onset_time_var)))]
    cal_times <- long_data[relevant_subset == 1, sort(unique(get(cal_time_var)))]
    min_onset_time <- min(onset_times)
    max_onset_time <- max(onset_times)
    min_cal_time <- min(cal_times)
    max_cal_time <- max(cal_times)
    j <- 0
    stack_across_cohorts_balanced_treated_control <- list()
    if (max_onset_time > max_cal_time & !(never_treat_action %in%
        c("keep", "only"))) {
        last_treat_grp_time <- max_cal_time - (min_control_gap -
            1)
    } else if (max_onset_time > max_cal_time & (never_treat_action %in%
        c("keep", "only"))) {
        last_treat_grp_time <- max_cal_time
    } else if (max_onset_time <= max_cal_time) {
        last_treat_grp_time <- max_onset_time - min_control_gap
    }
    for (event_cohort in intersect(
        min_onset_time:last_treat_grp_time,
        onset_times
    )) {
        count_potential_controls <- dim(long_data[relevant_subset ==
            1 & between(get(onset_time_var), event_cohort +
            min_control_gap, event_cohort + max_control_gap,
        incbounds = TRUE
        )])[1]
        if (count_potential_controls <= 0) {
            flog.info(sprintf(
                "Given min_control_gap='%s' & max_control_gap='%s', no control units found for %s='%s'",
                min_control_gap, max_control_gap, onset_time_var,
                event_cohort
            ))
        } else {
            j <- j + 1
            possible_treated_control <- list()
            possible_treated_control[[1]] <- long_data[relevant_subset ==
                1 & get(onset_time_var) == event_cohort, unique(na.omit(c(
                outcomevar,
                unit_var, cal_time_var, onset_time_var, treated_subset_var,
                control_subset_var, treated_subset_var2, control_subset_var2, treated_subset_var3, control_subset_var3,
                cluster_vars, discrete_covars, cont_covars,
                reg_weights, ref_reg_weights, ref_discrete_covars,
                ref_cont_covars, ntile_var, endog_var, linearDiD_treat_var,
                ref_event_time_mean_vars, compare_var
            ))), with = FALSE]
            gc()
            possible_treated_control[[1]][, `:=`(
                ref_onset_time,
                event_cohort
            )]
            possible_treated_control[[1]][, `:=`(treated, 1)]
            possible_treated_control[[2]] <- long_data[relevant_subset ==
                1 & between(get(onset_time_var), event_cohort +
                min_control_gap, event_cohort + max_control_gap,
            incbounds = TRUE
            ), unique(na.omit(c(
                outcomevar,
                unit_var, cal_time_var, onset_time_var, treated_subset_var,
                control_subset_var, treated_subset_var2, control_subset_var2, treated_subset_var3, control_subset_var3,
                cluster_vars, discrete_covars, cont_covars,
                reg_weights, ref_reg_weights, ref_discrete_covars,
                ref_cont_covars, ntile_var, endog_var, linearDiD_treat_var,
                ref_event_time_mean_vars, compare_var
            ))), with = FALSE]
            if (never_treat_action == "only") {
                possible_treated_control[[2]] <- possible_treated_control[[2]][get(onset_time_var) ==
                    never_treat_val]
            }
            gc()
            possible_treated_control[[2]][, `:=`(
                ref_onset_time,
                event_cohort
            )]
            possible_treated_control[[2]][, `:=`(treated, 0)]
            possible_treated_control <- rbindlist(possible_treated_control,
                use.names = TRUE
            )
            gc()
            possible_treated_control[, `:=`(
                ref_event_time,
                get(cal_time_var) - ref_onset_time
            )]
            if (event_vs_noevent == FALSE) {
                possible_treated_control <- possible_treated_control[(treated ==
                    1) | ((treated == 0) & (get(cal_time_var) <
                    get(onset_time_var) - anticipation))]
            }
            max_control_year <- possible_treated_control[treated ==
                0, max(get(cal_time_var))]
            possible_treated_control <- possible_treated_control[get(cal_time_var) <=
                max_control_year]
            gc()
            i <- 0
            balanced_treated_control <- list()
            temp <- possible_treated_control[, .N, by = "ref_event_time"]
            years <- sort(unique(temp$ref_event_time))
            temp <- NULL
            gc()
            if (!(is.null(ntile_var)) & ntile_avg == TRUE) {
                possible_treated_control[, `:=`(pre1, sum2((ref_event_time ==
                    omitted_event_time) * get(ntile_var), na.rm = TRUE)),
                by = list(get(unit_var))
                ]
                if ((omitted_event_time + anticipation < -1)) {
                    possible_treated_control[, `:=`(pre2, sum2((ref_event_time ==
                        (omitted_event_time + 1)) * get(ntile_var),
                    na.rm = TRUE
                    )), by = list(get(unit_var))]
                } else if ((omitted_event_time + anticipation ==
                    -1)) {
                    possible_treated_control[, `:=`(pre2, sum2((ref_event_time ==
                        (omitted_event_time - 1)) * get(ntile_var),
                    na.rm = TRUE
                    )), by = list(get(unit_var))]
                }
                possible_treated_control[, `:=`(
                    avg_ntile_var,
                    ((pre1 + pre2) / 2)
                )]
                possible_treated_control[, `:=`(
                    (ntile_var),
                    avg_ntile_var
                )]
                possible_treated_control[, `:=`(
                    avg_ntile_var,
                    NULL
                )]
                gc()
            }
            for (time_since_event in setdiff(years, omitted_event_time)) {
                i <- i + 1
                if (time_since_event < 1 | event_vs_noevent ==
                    TRUE) {
                    balanced_treated_control[[i]] <- possible_treated_control[(get(onset_time_var) ==
                        event_cohort & ref_event_time %in% c(
                        omitted_event_time,
                        time_since_event
                    )) | (get(onset_time_var) >
                        event_cohort & ref_event_time %in% c(
                        omitted_event_time,
                        time_since_event
                    ))]
                    gc()
                } else if (time_since_event >= 1) {
                    balanced_treated_control[[i]] <- possible_treated_control[(get(onset_time_var) ==
                        event_cohort & ref_event_time %in% c(
                        omitted_event_time,
                        time_since_event
                    )) | (get(onset_time_var) >
                        (event_cohort + time_since_event) & ref_event_time %in%
                        c(omitted_event_time, time_since_event))]
                    gc()
                }
                balanced_treated_control[[i]] <- na.omit(balanced_treated_control[[i]],
                    cols = unique(na.omit(c(
                        outcomevar, unit_var,
                        cal_time_var, onset_time_var, cluster_vars,
                        discrete_covars, cont_covars, reg_weights,
                        ntile_var, endog_var, linearDiD_treat_var,
                        "ref_onset_time", "ref_event_time", "treated"
                    )))
                )
                gc()
                balanced_treated_control[[i]][, `:=`(
                    catt_specific_sample,
                    i
                )]
                if (!(is.null(ntile_var))) {
                    q <- quantile(balanced_treated_control[[i]][treated ==
                        1 & ref_event_time == ntile_event_time][[ntile_var]],
                    probs = (seq(1, (ntiles - 1), by = 1) / ntiles)
                    )
                    gc()
                    balanced_treated_control[[i]][, `:=`(
                        relevant_level,
                        sum(get(ntile_var) * (ref_event_time ==
                            ntile_event_time), na.rm = TRUE)
                    ), by = list(get(unit_var))]
                    balanced_treated_control[[i]][relevant_level >
                        q[(ntiles - 1)], `:=`(catt_ntile, ntiles)]
                    for (qt in ((ntiles - 1):1)) {
                        balanced_treated_control[[i]][relevant_level <=
                            q[qt], `:=`(catt_ntile, qt)]
                    }
                    if (!(is.na(ntile_var_value))) {
                        balanced_treated_control[[i]] <- balanced_treated_control[[i]][catt_ntile ==
                            ntile_var_value]
                        balanced_treated_control[[i]][, `:=`(
                            catt_ntile,
                            NULL
                        )]
                    } else {
                        catt_ntile_var <- "catt_ntile"
                    }
                    balanced_treated_control[[i]][, `:=`(
                        relevant_level,
                        NULL
                    )]
                    rm(q)
                    gc()
                }

                if (!(exists("catt_ntile_var"))) {
                    catt_ntile_var <- NULL
                } # so the column selection later doesn't break
            }
            possible_treated_control <- NULL
            gc()
            balanced_treated_control <- rbindlist(balanced_treated_control,
                use.names = TRUE
            )
            balanced_treated_control <- na.omit(balanced_treated_control,
                cols = unique(na.omit(c(
                    outcomevar, unit_var,
                    cal_time_var, onset_time_var, cluster_vars,
                    discrete_covars, cont_covars, reg_weights,
                    ntile_var, endog_var, linearDiD_treat_var,
                    "ref_onset_time", "ref_event_time", "catt_specific_sample",
                    "treated"
                )))
            )
            gc()
            stack_across_cohorts_balanced_treated_control[[j]] <- balanced_treated_control[,
                unique(na.omit(c(
                    outcomevar, unit_var, cal_time_var,
                    onset_time_var, treated_subset_var, control_subset_var,
                    treated_subset_var2, control_subset_var2, treated_subset_var3, control_subset_var3,
                    cluster_vars, discrete_covars, cont_covars,
                    reg_weights, ref_reg_weights, ref_discrete_covars,
                    ref_cont_covars, ntile_var, endog_var, linearDiD_treat_var,
                    ref_event_time_mean_vars, compare_var, catt_ntile_var, "ref_onset_time",
                    "ref_event_time", "catt_specific_sample",
                    "treated"
                ))),
                with = FALSE
            ]
            gc()
            balanced_treated_control <- NULL
            gc()
        }
    }
    stack_across_cohorts_balanced_treated_control <- rbindlist(stack_across_cohorts_balanced_treated_control,
        use.names = TRUE
    )
    gc()
    if (!is.na(treated_subset_var) & is.na(control_subset_var)) {
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_treated_group,
            as.integer(max2((get(treated_subset_var) * (ref_event_time ==
                treated_subset_event_time)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[treated ==
            0 | valid_treated_group == 1]
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_treated_group,
            NULL
        )]
        gc()
    }
    if (!is.na(control_subset_var) & is.na(treated_subset_var)) {
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_control_group,
            as.integer(max2((get(control_subset_var) * (ref_event_time ==
                control_subset_event_time)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[valid_control_group ==
            1 | treated == 1]
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_control_group,
            NULL
        )]
        gc()
    }
    if (!is.na(control_subset_var) & !is.na(treated_subset_var)) {
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_treated_group,
            as.integer(max2((get(treated_subset_var) * (ref_event_time ==
                treated_subset_event_time)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_control_group,
            as.integer(max2((get(control_subset_var) * (ref_event_time ==
                control_subset_event_time)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[(treated ==
            0 & valid_control_group == 1) | (treated == 1 &
            valid_treated_group == 1)]
        stack_across_cohorts_balanced_treated_control[, `:=`(c(
            "valid_treated_group",
            "valid_control_group"
        ), NULL)]
        gc()
    }
    if (!is.na(treated_subset_var2) & is.na(control_subset_var2)) {
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_treated_group,
            as.integer(max2((get(treated_subset_var2) * (ref_event_time ==
                treated_subset_event_time2)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[treated ==
            0 | valid_treated_group == 1]
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_treated_group,
            NULL
        )]
        gc()
    }
    if (!is.na(control_subset_var2) & is.na(treated_subset_var2)) {
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_control_group,
            as.integer(max2((get(control_subset_var2) * (ref_event_time ==
                control_subset_event_time2)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[valid_control_group ==
            1 | treated == 1]
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_control_group,
            NULL
        )]
        gc()
    }
    if (!is.na(control_subset_var2) & !is.na(treated_subset_var2)) {
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_treated_group,
            as.integer(max2((get(treated_subset_var2) * (ref_event_time ==
                treated_subset_event_time2)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_control_group,
            as.integer(max2((get(control_subset_var2) * (ref_event_time ==
                control_subset_event_time2)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[(treated ==
            0 & valid_control_group == 1) | (treated == 1 &
            valid_treated_group == 1)]
        stack_across_cohorts_balanced_treated_control[, `:=`(c(
            "valid_treated_group",
            "valid_control_group"
        ), NULL)]
        gc()
    }
    if (!is.na(control_subset_var3) & !is.na(treated_subset_var3)) {
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_treated_group,
            as.integer(max2((get(treated_subset_var3) * (ref_event_time ==
                treated_subset_event_time3)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control[, `:=`(
            valid_control_group,
            as.integer(max2((get(control_subset_var3) * (ref_event_time ==
                control_subset_event_time3)), na.rm = TRUE))
        ),
        by = c(unit_var, "ref_onset_time")
        ]
        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[(treated ==
            0 & valid_control_group == 1) | (treated == 1 &
            valid_treated_group == 1)]
        stack_across_cohorts_balanced_treated_control[, `:=`(c(
            "valid_treated_group",
            "valid_control_group"
        ), NULL)]
        gc()
    }
    if (never_treat_action == "only") {
        rm(never_treat_val)
    }
    long_data[, `:=`(relevant_subset, NULL)]
    gc()

    if (case == "balanced_in_event_time") {

        # In this case, we require that the panel has been artificially balanced (with NAs) before
        # hitting this function

        # In general, we just want to remove any catt_specific_samples where
        # there aren't both pre- and post-treatment observations available for
        # both the treatment and control groups

        stack_across_cohorts_balanced_treated_control[, to_keep_imbens_etal := as.integer(max(ref_event_time) != min(ref_event_time)),
            by = .(ref_onset_time, catt_specific_sample, treated)
        ]

        stack_across_cohorts_balanced_treated_control[, to_keep_imbens_etal := min(to_keep_imbens_etal),
            by = .(ref_onset_time, catt_specific_sample)
        ]

        stack_across_cohorts_balanced_treated_control <- stack_across_cohorts_balanced_treated_control[to_keep_imbens_etal == 1]
        stack_across_cohorts_balanced_treated_control[, to_keep_imbens_etal := NULL]
        gc()
    }

    flog.info(sprintf(
        "Successfully produced a stacked dataset with %s rows.",
        format(dim(stack_across_cohorts_balanced_treated_control)[1],
            scientific = FALSE, big.mark = ","
        )
    ))
    return(stack_across_cohorts_balanced_treated_control)
}

makeRatioVcov <- function(coefs_fs, coefs_rf, vcov_fs, vcov_rf) {

    # Variance-covariance matrix of the coefficient-by-coefficient ratio estimator
    # coefs_rf / coefs_fs (reduced form over first stage).
    #
    # As we don't have cross-equation correlations, we take the approach as in, for example,
    # Dee and Evans (2003), Pacini and Windmeijer (2016), and Sampat and Williams (2018),
    # i.e., treating the cross-equation covariance of residuals as 0.
    #
    # Arguments:
    #   coefs_fs: numeric vector of first-stage coefficients (length k).
    #   coefs_rf: numeric vector of reduced-form coefficients (length k, same order as coefs_fs).
    #   vcov_fs:  first-stage variance-covariance matrix (k x k); only its diagonal is used.
    #   vcov_rf:  reduced-form variance-covariance matrix (k x k); only its diagonal is used.
    #             For k = 1 a bare numeric scalar variance is also accepted.
    #
    # Returns a k x k diagonal matrix (no dimnames) whose i-th diagonal entry is the
    # first-order Taylor / delta-method variance of coefs_rf[i] / coefs_fs[i]:
    #   (vcov_rf[i, i] + (coefs_rf[i] / coefs_fs[i])^2 * vcov_fs[i, i]) / coefs_fs[i]^2
    #
    # This is eqn 12 of Pacini and Windmeijer (2016),
    #   C_hat Var_pi_y1 C_hat' + (beta_ts2sls' %x% C_hat) Var_pi_x2 (beta_ts2sls' %x% C_hat)',
    # with C_hat = diag(1 / coefs_fs) and both Var matrices diagonal. Because every factor is
    # diagonal, the quadratic forms reduce to element-wise arithmetic on the diagonals.
    # Computing them that way avoids solve() on diag(coefs_fs), which errors whenever
    # min|coefs_fs| / max|coefs_fs| drops below solve()'s default tolerance, or any
    # coefs_fs is 0. A zero first-stage coefficient now yields a non-finite (Inf/NaN)
    # variance for that entry only, matching the non-finite ratio estimate itself.

    k <- length(coefs_fs)
    if (length(coefs_rf) != k) {
        stop(sprintf(
            "makeRatioVcov: coefs_rf has length %s but coefs_fs has length %s.",
            length(coefs_rf), k
        ))
    }

    # Diagonals of the RF and FS variance-covariance matrices
    # (notation matches Pacini and Windmeijer (2016): Var_pi_y1 and Var_pi_x2).
    # as.matrix() lets a bare scalar variance through for the single-coefficient case.
    Var_pi_y1 <- unname(diag(as.matrix(vcov_rf)))
    Var_pi_x2 <- unname(diag(as.matrix(vcov_fs)))
    if (length(Var_pi_y1) != k || length(Var_pi_x2) != k) {
        # Element-wise arithmetic would otherwise recycle silently on a size mismatch.
        stop(sprintf(
            "makeRatioVcov: vcov_rf and vcov_fs must have %s diagonal elements (got %s and %s).",
            k, length(Var_pi_y1), length(Var_pi_x2)
        ))
    }

    # Diagonal of C_hat, the inverse of the (diagonal) first-stage matrix.
    C_hat <- 1 / coefs_fs
    beta_ts2sls <- coefs_rf / coefs_fs

    # C_hat Var_pi_y1 C_hat'  +  (beta_ts2sls' %x% C_hat) Var_pi_x2 (beta_ts2sls' %x% C_hat)'
    # evaluated on the diagonal.
    beta_ts2sls_var <- C_hat^2 * Var_pi_y1 + (beta_ts2sls * C_hat)^2 * Var_pi_x2

    # nrow is passed explicitly: with a single length-1 argument, diag() would instead
    # build an identity matrix of that size.
    diag(x = unname(beta_ts2sls_var), nrow = k)
}

Wald_ES2 <- function(long_data, outcomevar, unit_var, cal_time_var, onset_time_var,
                     cluster_vars, omitted_event_time = -2, anticipation = 0,
                     min_control_gap = 1, max_control_gap = Inf, linearize_pretrends = FALSE,
                     fill_zeros = FALSE, residualize_covariates = FALSE, control_subset_var = NA,
                     control_subset_event_time = 0, treated_subset_var = NA,
                     treated_subset_event_time = 0, control_subset_var2 = NA,
                     control_subset_event_time2 = 0, treated_subset_var2 = NA,
                     treated_subset_event_time2 = 0, discrete_covars = NULL,
                     cont_covars = NULL, never_treat_action = "none", homogeneous_ATT = FALSE,
                     reg_weights = NULL, add_unit_fes = FALSE, bootstrapES = FALSE,
                     bootstrap_iters = 1, bootstrap_num_cores = 1, keep_all_bootstrap_results = FALSE,
                     ipw = FALSE, ipw_model = "linear", ipw_composition_change = FALSE,
                     ipw_keep_data = FALSE, ipw_ps_lower_bound = 0, ipw_ps_upper_bound = 1,
                     event_vs_noevent = FALSE, ref_discrete_covars = NULL, ref_discrete_covar_event_time = 0,
                     ref_cont_covars = NULL, ref_cont_covar_event_time = 0, calculate_collapse_estimates = FALSE,
                     collapse_inputs = NULL, ref_reg_weights = NULL, ref_reg_weights_event_time = 0,
                     ntile_var = NULL, ntile_event_time = -2, ntiles = NA, ntile_var_value = NA,
                     ntile_avg = FALSE, endog_var = NULL, cohort_by_cohort = FALSE,
                     cohort_by_cohort_num_cores = 1, heterogeneous_only = TRUE,
                     ref_event_time_mean_vars = NULL,
                     ref_event_time_mean_var_event_time1 = NULL,
                     ref_event_time_mean_var_event_time2 = NULL,
                     ref_event_time_mean_var_event_time3 = NULL,
                     control_subset_var3 = NA,
                     control_subset_event_time3 = 0, treated_subset_var3 = NA,
                     treated_subset_event_time3 = 0,
                     compare_var = NULL,
                     compare_var_condition = "equal",
                     compare_var_initial_event_time = 0,
                     compare_var_final_event_time = 0,
                     compare_var_eq_tol = 0,
                     outcomevar_temp = NULL,
                     endog_ntiles = NA,
                     endog_ntile_value = NA,
                     return_clean_data = FALSE,
                     case = NULL) {
    flog.info("Beginning Wald_ES.")
    ES_check_inputs(
        long_data = long_data, outcomevar = outcomevar,
        unit_var = unit_var, cal_time_var = cal_time_var, onset_time_var = onset_time_var,
        cluster_vars = cluster_vars, omitted_event_time = omitted_event_time,
        anticipation = anticipation, min_control_gap = min_control_gap,
        max_control_gap = max_control_gap, linearize_pretrends = linearize_pretrends,
        fill_zeros = fill_zeros, residualize_covariates = residualize_covariates,
        control_subset_var = control_subset_var, control_subset_event_time = control_subset_event_time,
        treated_subset_var = treated_subset_var, treated_subset_event_time = treated_subset_event_time,
        control_subset_var2 = control_subset_var2, control_subset_event_time2 = control_subset_event_time2,
        treated_subset_var2 = treated_subset_var2, treated_subset_event_time2 = treated_subset_event_time2,
        discrete_covars = discrete_covars, cont_covars = cont_covars,
        never_treat_action = never_treat_action, homogeneous_ATT = homogeneous_ATT,
        reg_weights = reg_weights, add_unit_fes = add_unit_fes,
        bootstrapES = bootstrapES, bootstrap_iters = bootstrap_iters,
        bootstrap_num_cores = bootstrap_num_cores, keep_all_bootstrap_results,
        ipw = ipw, ipw_model = ipw_model, ipw_composition_change = ipw_composition_change,
        ipw_keep_data = ipw_keep_data, ipw_ps_lower_bound = ipw_ps_lower_bound,
        ipw_ps_upper_bound = ipw_ps_upper_bound, event_vs_noevent = event_vs_noevent,
        ref_discrete_covars = ref_discrete_covars, ref_discrete_covar_event_time = ref_discrete_covar_event_time,
        ref_cont_covars = ref_cont_covars, ref_cont_covar_event_time = ref_cont_covar_event_time,
        calculate_collapse_estimates = calculate_collapse_estimates,
        collapse_inputs = collapse_inputs, ref_reg_weights = ref_reg_weights,
        ref_reg_weights_event_time = ref_reg_weights_event_time,
        ntile_var = ntile_var, ntile_event_time = ntile_event_time,
        ntiles = ntiles, ntile_var_value = ntile_var_value,
        ntile_avg = ntile_avg, endog_var = endog_var, cohort_by_cohort = cohort_by_cohort,
        cohort_by_cohort_num_cores = cohort_by_cohort_num_cores,
        heterogeneous_only = heterogeneous_only, ref_event_time_mean_vars = ref_event_time_mean_vars
    )
    if (bootstrapES == TRUE) {
        original_sample <- copy(long_data)
    }
    if (never_treat_action == "exclude") {
        never_treat_val <- NA
        long_data <- long_data[!is.na(get(onset_time_var))]
        gc()
    } else if (never_treat_action %in% c("keep", "only")) {
        never_treat_val <- max(max(long_data[[onset_time_var]],
            na.rm = TRUE
        ), max(long_data[[cal_time_var]], na.rm = TRUE)) +
            min_control_gap + anticipation + 1
        long_data[is.na(get(onset_time_var)), `:=`(
            (onset_time_var),
            never_treat_val
        )]
    }
    if (fill_zeros) {
        flog.info("Filling in zeros.")
        long_data <- ES_expand_to_balance(
            long_data = long_data,
            vars_to_fill = na.omit(unique(c(outcomevar, endog_var))),
            unit_var = unit_var, cal_time_var = cal_time_var,
            onset_time_var = onset_time_var
        )
    }
    if (is.infinite(suppressWarnings(long_data[get(cal_time_var) -
        get(onset_time_var) == omitted_event_time, min(get(onset_time_var))]))) {
        stop(sprintf(
            "Variable onset_time_var='%s' has no treated groups with observations at pre-treatment event time %s.",
            onset_time_var, omitted_event_time
        ))
    }
    if (!(is.null(reg_weights))) {
        count_long_data_initial <- dim(long_data)[1]
        count_reg_weights_isna <- dim(long_data[is.na(get(reg_weights))])[1]
        count_reg_weights_isinf <- dim(long_data[is.infinite(get(reg_weights))])[1]
        count_reg_weights_nonpositive <- dim(long_data[get(reg_weights) <=
            0])[1]
        long_data <- long_data[(!is.na(get(reg_weights))) &
            (!is.infinite(get(reg_weights))) & (get(reg_weights) >
            0)]
        gc()
        count_dropped <- (count_long_data_initial - dim(long_data)[1])
        rm(count_long_data_initial)
        if (count_dropped != 0) {
            flog.info((sprintf(
                "\n Warning: Droppped %s from long_data due to missing or extreme reg_weights. \n Of those dropped, the breakdown is: \n 1) %s%% had NA reg_weights \n 2) %s%% had Inf/-Inf reg_weights \n 3) %s%% had reg_weights <= 0",
                format(count_dropped, scientific = FALSE, big.mark = ","),
                round(((count_reg_weights_isna / count_dropped) *
                    100), digits = 4), round(((count_reg_weights_isinf / count_dropped) *
                    100), digits = 4), round(((count_reg_weights_nonpositive / count_dropped) *
                    100), digits = 4)
            )))
        }
        rm(
            count_reg_weights_isna, count_reg_weights_isinf,
            count_reg_weights_nonpositive, count_dropped
        )
        gc()
    }
    if (linearize_pretrends) {
        flog.info("NOT CURRENTLY Linearizing pre-trends.")
    }
    if (residualize_covariates == TRUE) {
        flog.info("NOT CURRENTLY Residualizing on covariates.")
    }
    flog.info("Beginning data stacking.")
    ES_data <- ES_clean_data(
        long_data = long_data, outcomevar = outcomevar,
        unit_var = unit_var, cal_time_var = cal_time_var, onset_time_var = onset_time_var,
        anticipation = anticipation, min_control_gap = min_control_gap,
        max_control_gap = max_control_gap, omitted_event_time = omitted_event_time,
        control_subset_var = control_subset_var, control_subset_event_time = control_subset_event_time,
        treated_subset_var = treated_subset_var, treated_subset_event_time = treated_subset_event_time,
        control_subset_var2 = control_subset_var2, control_subset_event_time2 = control_subset_event_time2,
        treated_subset_var2 = treated_subset_var2, treated_subset_event_time2 = treated_subset_event_time2,
        treated_subset_var3 = treated_subset_var3, treated_subset_event_time3 = treated_subset_event_time3,
        control_subset_var3 = control_subset_var3, control_subset_event_time3 = control_subset_event_time3,
        never_treat_action = never_treat_action, never_treat_val = never_treat_val,
        cluster_vars = cluster_vars, discrete_covars = unique(discrete_covars, "mtr_observed"),
        cont_covars = cont_covars, reg_weights = reg_weights,
        event_vs_noevent = event_vs_noevent, ref_discrete_covars = ref_discrete_covars,
        ref_cont_covars = ref_cont_covars, ref_reg_weights = ref_reg_weights,
        ntile_var = ntile_var, ntile_event_time = ntile_event_time,
        ntiles = ntiles, ntile_var_value = ntile_var_value,
        ntile_avg = ntile_avg, endog_var = endog_var, ref_event_time_mean_vars = ref_event_time_mean_vars,
        compare_var = compare_var, case = case
    )
    print(sprintf("Mean of the outcomvar in the stacked data: %s", print(summary(ES_data[[outcomevar]]))))

    # Now grab lower thresholds (if ntile_var supplied)
    if (!(is.null(ntile_var))) {
        bar_y <- ES_data[treated == 1 & ref_event_time == omitted_event_time, min(get(ntile_var)), by = list(ref_onset_time, catt_specific_sample)]
    }

    # Now break the data along quantiles of the endogenous variable, if relevant
    if (!is.na(endog_ntiles)) {
        if (endog_ntiles != 1) {
            # Split the data along the 'endog_var' according to 'endog_ntiles', and keep the ntile matching 'endog_ntile_value'
            time_invariant <- ES_data[, c(unit_var, endog_var), with = FALSE]
            gc()

            # time_invariant will have duplicates for future winners appearing multiple times
            time_invariant[, unit_id := seq_len(.N), by = unit_var]
            time_invariant <- time_invariant[unit_id == 1]
            time_invariant[, unit_id := NULL]
            gc()

            q <- quantile(time_invariant[[endog_var]],
                probs = (seq(1, (endog_ntiles - 1), by = 1) / endog_ntiles)
            )
            gc()

            time_invariant[get(endog_var) > q[(endog_ntiles - 1)], endog_ntile := endog_ntiles]
            for (qt in ((endog_ntiles - 1):1)) {
                time_invariant[get(endog_var) <= q[[qt]], endog_ntile := qt]
            }
            rm(qt)
            print(summary(time_invariant))
            flog.info(q)

            # Restrict to case-specific endog_ntile_value
            time_invariant <- time_invariant[endog_ntile == endog_ntile_value, c(unit_var), with = FALSE]
            ES_data <- merge(ES_data, time_invariant, by = unit_var, sort = FALSE)

            rm(time_invariant, q)
            gc()
        }

        if (grepl("multiperiod", endog_var)) {
            ES_data[get(cal_time_var) < get(onset_time_var), (endog_var) := 0]
        }

        flog.info(sprintf(
            "Successfully produced a stacked dataset with %s rows after restricting to ntile=%s (of %s ntiles) of endog_var='%s'",
            format(dim(ES_data)[1],
                scientific = FALSE, big.mark = ","
            ), endog_ntile_value, endog_ntiles, endog_var
        ))
    }

    # code to make the ref_onset_time based outcome variables
    if (!(is.null(outcomevar_temp))) {
        if (outcomevar_temp %in% c(
            "divorce",
            "divorce_ch",
            "new_divorce_hazard",
            "new_divorce_hazard_ch",
            "new_marriage",
            "new_marriage_ch",
            "new_marriage_hazard",
            "new_marriage_hazard_ch"
        )) {
            outcomevar <- copy(outcomevar_temp)
            flog.info(outcomevar)

            # Will construct a version comparing each "post" year to omitted_event_time
            # and then a "hazard" version

            if (outcomevar_temp %in% c(
                "divorce",
                "divorce_ch",
                "new_marriage",
                "new_marriage_ch"
            )) {
                ES_data[ref_event_time == -2, base_hhkey := hh_key]
                ES_data[, base_hhkey := lapply(.SD, function(x) na.omit(x)[1]), .SDcols = "base_hhkey", by = c(unit_var, "ref_onset_time")]

                ES_data[, (outcomevar) := as.integer(base_hhkey != hh_key)]
            }

            if (outcomevar_temp %in% c(
                "new_divorce_hazard",
                "new_divorce_hazard_ch",
                "new_marriage_hazard",
                "new_marriage_hazard_ch"
            )) {
                ES_data[ref_event_time == -2, base_hhkey := hh_key]
                ES_data[, base_hhkey := lapply(.SD, function(x) na.omit(x)[1]), .SDcols = "base_hhkey", by = c(unit_var, "ref_onset_time")]

                ES_data[, (outcomevar) := as.integer(base_hhkey != hh_key)]
                ES_data[ref_event_time < omitted_event_time, (outcomevar) := 0]

                # Identify the first post-treatment ref_event_time with a marital status change (if any)
                # Will be specific to a unit_var and ref_onset_time

                ES_data[, min_year_change := get(outcomevar) * (ref_onset_time + ref_event_time)]
                ES_data[ref_event_time < omitted_event_time | min_year_change == 0, min_year_change := 9999]
                ES_data[, min_year_change := min(min_year_change), by = c(unit_var, "ref_onset_time")]
                ES_data[min_year_change == 9999, (outcomevar) := 0]
                ES_data[min_year_change != 9999, min_year_change := (min_year_change - ref_onset_time)]
                ES_data[ref_event_time >= min_year_change, (outcomevar) := 1]

                ES_data[, c("min_year_change") := NULL]
            }

            ES_data[, base_hhkey := NULL]
        }
    }

    # code to make job-to-job outcomes relative to omitted_event_time

    if (outcomevar %in% c("firm_move")) {

        # Need to define these outcomes
        # For all, will be 0 in -2

        ES_data[, (outcomevar) := NULL]

        ES_data[ref_event_time == -2, base_firm := payer_tin]
        ES_data[, base_firm := lapply(.SD, function(x) na.omit(x)[1]), .SDcols = "base_firm", by = c(unit_var, "ref_onset_time")]

        ES_data[, (outcomevar) := as.integer((base_firm != payer_tin) & (main_job_W2 == 1))]
    }

    if (outcomevar %in% c("firm_move_up")) {

        # Need to define these outcomes
        # For all, will be 0 in -2

        ES_data[, (outcomevar) := NULL]

        ES_data[ref_event_time == -2, base_firm := payer_tin]
        ES_data[, base_firm := lapply(.SD, function(x) na.omit(x)[1]), .SDcols = "base_firm", by = c(unit_var, "ref_onset_time")]

        ES_data[ref_event_time == -2, base_wage := mean_weighted_resid]
        ES_data[, base_wage := lapply(.SD, function(x) na.omit(x)[1]), .SDcols = "base_wage", by = c(unit_var, "ref_onset_time")]

        ES_data[, (outcomevar) := as.integer((base_firm != payer_tin) & (main_job_W2 == 1) & (mean_weighted_resid >= base_wage))]
    }

    if (outcomevar %in% c("firm_move_down")) {

        # Need to define these outcomes
        # For all, will be 0 in -2

        ES_data[, (outcomevar) := NULL]

        ES_data[ref_event_time == -2, base_firm := payer_tin]
        ES_data[, base_firm := lapply(.SD, function(x) na.omit(x)[1]), .SDcols = "base_firm", by = c(unit_var, "ref_onset_time")]

        ES_data[ref_event_time == -2, base_wage := mean_weighted_resid]
        ES_data[, base_wage := lapply(.SD, function(x) na.omit(x)[1]), .SDcols = "base_wage", by = c(unit_var, "ref_onset_time")]

        ES_data[, (outcomevar) := as.integer((base_firm != payer_tin) & (main_job_W2 == 1) & (mean_weighted_resid < base_wage))]
    }

    # compare_var code
    if (!is.null(compare_var)) {

        # Record the variable at compare_var_initial_event_time
        ES_data[, compare_var_initial := get(compare_var)]
        ES_data[
            ref_event_time != compare_var_initial_event_time,
            `:=`(compare_var_initial, NA)
        ]
        if (compare_var_initial_event_time == omitted_event_time) {
            if (class(ES_data[[compare_var]]) == "integer") {
                ES_data[, `:=`(compare_var_initial, as.integer(sum2(compare_var_initial *
                    (ref_event_time == compare_var_initial_event_time),
                na.rm = TRUE
                ))), by = c(
                    unit_var, "ref_onset_time",
                    "catt_specific_sample"
                )]
            } else if (class(ES_data[[compare_var]]) == "numeric") {
                ES_data[, `:=`(compare_var_initial, as.numeric(sum2(compare_var_initial *
                    (ref_event_time == compare_var_initial_event_time),
                na.rm = TRUE
                ))), by = c(
                    unit_var, "ref_onset_time",
                    "catt_specific_sample"
                )]
            }
        } else {
            if (class(ES_data[[compare_var]]) == "integer") {
                ES_data[, `:=`(compare_var_initial, as.integer(sum2(compare_var_initial *
                    (ref_event_time == compare_var_initial_event_time),
                na.rm = TRUE
                ))), by = c(unit_var, "ref_onset_time")]
            } else if (class(ES_data[[compare_var]]) == "numeric") {
                ES_data[, `:=`(compare_var_initial, as.numeric(sum2(compare_var_initial *
                    (ref_event_time == compare_var_initial_event_time),
                na.rm = TRUE
                ))), by = c(unit_var, "ref_onset_time")]
            }
        }

        # Record the variable at compare_var_final_event_time
        ES_data[, compare_var_final := get(compare_var)]
        ES_data[
            ref_event_time != compare_var_final_event_time,
            `:=`(compare_var_final, NA)
        ]
        if (compare_var_final_event_time == omitted_event_time) {
            if (class(ES_data[[compare_var]]) == "integer") {
                ES_data[, `:=`(compare_var_final, as.integer(sum2(compare_var_final *
                    (ref_event_time == compare_var_final_event_time),
                na.rm = TRUE
                ))), by = c(
                    unit_var, "ref_onset_time",
                    "catt_specific_sample"
                )]
            } else if (class(ES_data[[compare_var]]) == "numeric") {
                ES_data[, `:=`(compare_var_final, as.numeric(sum2(compare_var_final *
                    (ref_event_time == compare_var_final_event_time),
                na.rm = TRUE
                ))), by = c(
                    unit_var, "ref_onset_time",
                    "catt_specific_sample"
                )]
            }
        } else {
            if (class(ES_data[[compare_var]]) == "integer") {
                ES_data[, `:=`(compare_var_final, as.integer(sum2(compare_var_final *
                    (ref_event_time == compare_var_final_event_time),
                na.rm = TRUE
                ))), by = c(unit_var, "ref_onset_time")]
            } else if (class(ES_data[[compare_var]]) == "numeric") {
                ES_data[, `:=`(compare_var_final, as.numeric(sum2(compare_var_final *
                    (ref_event_time == compare_var_final_event_time),
                na.rm = TRUE
                ))), by = c(unit_var, "ref_onset_time")]
            }
        }

        if (compare_var_condition == "equal") {
            ES_data <- ES_data[abs(compare_var_initial - compare_var_final) <= compare_var_eq_tol] # NOTE: this will exclude cases where both are NA
            gc()
        } else if (compare_var_condition == "unequal") {
            ES_data <- ES_data[abs(compare_var_initial - compare_var_final) > compare_var_eq_tol] # NOTE: this will exclude cases where either/both are NA
            gc()
        }

        ES_data[, c("compare_var_initial", "compare_var_final") := NULL]
        gc()

        flog.info(sprintf(
            "Updated size of stacked dataset due to compare_var='%s' with condition '%s': %s rows.",
            compare_var, compare_var_condition, format(dim(ES_data)[1],
                scientific = FALSE, big.mark = ","
            )
        ))
    }
    # Reference-period discrete covariates. For each variable in
    # `ref_discrete_covars`, keep only its value at
    # `ref_discrete_covar_event_time` and broadcast it to every row of the same
    # unit x ref_onset_time group (additionally split by catt_specific_sample
    # when the reference time equals the omitted event time). Variables that
    # are also time-varying `discrete_covars` get a separate "<var>_dyn" copy
    # so the original column is left untouched. Character variables are
    # recoded to integer group ids, with NA and "" mapped to NA.
    if (!is.null(ref_discrete_covars)) {
        start_cols <- copy(colnames(ES_data))
        # Grouping used to broadcast the reference-period value; it does not
        # depend on the loop variable, so compute it once.
        by_cols <- c(unit_var, "ref_onset_time")
        if (ref_discrete_covar_event_time == omitted_event_time) {
            by_cols <- c(by_cols, "catt_specific_sample")
        }
        for (var in ref_discrete_covars) {
            # is.character() rather than class(x) == "character": class() can
            # have length > 1 (e.g. POSIXct), which makes `if` error in R >= 4.2.
            is_chr <- is.character(ES_data[[var]])
            if (var %in% discrete_covars) {
                varname <- sprintf("%s_dyn", var)
                if (is_chr) {
                    # keyby can re-sort the table; sortorder restores the
                    # original row order afterwards.
                    ES_data[, `:=`(sortorder, .I)]
                    ES_data[, `:=`((varname), .GRP), keyby = var]
                    ES_data[
                        (is.na(get(var)) | get(var) == ""),
                        `:=`((varname), NA)
                    ]
                    setorderv(ES_data, "sortorder")
                    ES_data[, `:=`(sortorder, NULL)]
                } else {
                    ES_data[, `:=`((varname), get(var))]
                }
            } else {
                varname <- var
                if (is_chr) {
                    # Recode in place: build integer ids in a temporary column,
                    # then replace the original character column with it.
                    ES_data[, `:=`(sortorder, .I)]
                    ES_data[, `:=`(tempvar, .GRP), keyby = var]
                    ES_data[
                        (is.na(get(var)) | get(var) == ""),
                        `:=`(tempvar, NA)
                    ]
                    ES_data[, `:=`((varname), NULL)]
                    setnames(ES_data, "tempvar", varname)
                    setorderv(ES_data, "sortorder")
                    ES_data[, `:=`(sortorder, NULL)]
                }
            }
            # Blank out non-reference rows, then spread the reference value
            # to the whole group.
            ES_data[
                ref_event_time != ref_discrete_covar_event_time,
                `:=`((varname), NA)
            ]
            ES_data[, `:=`((varname), as.integer(sum2(get(varname) *
                (ref_event_time == ref_discrete_covar_event_time),
            na.rm = TRUE
            ))), by = by_cols]
        }
        ref_discrete_covars <- unique(na.omit(c(
            ref_discrete_covars,
            setdiff(colnames(ES_data), start_cols)
        )))
        rm(start_cols, var, varname, by_cols, is_chr)
    }
    # Reference-period continuous covariates. For each variable in
    # `ref_cont_covars`, keep only its value at `ref_cont_covar_event_time` and
    # broadcast it to every row of the same unit x ref_onset_time group
    # (additionally split by catt_specific_sample when the reference time
    # equals the omitted event time). Variables that are also time-varying
    # `cont_covars` get a separate "<var>_dyn" copy; otherwise the column is
    # overwritten. Integer columns stay integer; everything else is stored
    # as numeric.
    if (!is.null(ref_cont_covars)) {
        start_cols <- copy(colnames(ES_data))
        # Grouping used to broadcast the reference-period value; it does not
        # depend on the loop variable, so compute it once.
        by_cols <- c(unit_var, "ref_onset_time")
        if (ref_cont_covar_event_time == omitted_event_time) {
            by_cols <- c(by_cols, "catt_specific_sample")
        }
        for (var in ref_cont_covars) {
            if (var %in% cont_covars) {
                varname <- sprintf("%s_dyn", var)
                ES_data[, `:=`((varname), get(var))]
            } else {
                varname <- var
            }
            ES_data[
                ref_event_time != ref_cont_covar_event_time,
                `:=`((varname), NA)
            ]
            # is.integer() rather than class(x) == "integer", which yields a
            # length > 1 condition for multi-class columns (an error in
            # R >= 4.2). One assignment replaces the four near-identical
            # integer/numeric x grouping branches.
            cast_fn <- if (is.integer(ES_data[[var]])) as.integer else as.numeric
            ES_data[, `:=`((varname), cast_fn(sum2(get(varname) *
                (ref_event_time == ref_cont_covar_event_time),
            na.rm = TRUE
            ))), by = by_cols]
        }
        ref_cont_covars <- unique(na.omit(c(
            ref_cont_covars,
            setdiff(colnames(ES_data), start_cols)
        )))
        rm(start_cols, by_cols)
    }
    # Reference-period regression weights. Every unit x ref_onset_time group
    # gets the first non-missing weight observed at `ref_reg_weights_event_time`
    # (NA if there is none). Rows whose weight is then NA, infinite or <= 0
    # are dropped, and the breakdown of the dropped rows is logged.
    if (!is.null(ref_reg_weights)) {
        if (is.integer(ES_data[[ref_reg_weights]])) {
            ES_data[, `:=`((ref_reg_weights), as.numeric(get(ref_reg_weights)))]
        }
        # unique() has no na.rm argument: the former
        # `unique(x, na.rm = TRUE)[1]` passed na.rm into `...`, where it was
        # ignored, so a leading NA could win over a valid weight. Missing
        # values are now removed explicitly before taking the first value.
        first_non_missing <- function(x) x[!is.na(x)][1L]
        ES_data[, `:=`((ref_reg_weights), first_non_missing(
            get(ref_reg_weights)[ref_event_time == ref_reg_weights_event_time]
        )), by = c(unit_var, "ref_onset_time")]
        count_ES_initial <- nrow(ES_data)
        # Count and filter on the weight vector directly instead of
        # materialising row subsets of ES_data just to count them.
        ref_w <- ES_data[[ref_reg_weights]]
        count_ref_reg_weights_isna <- sum(is.na(ref_w))
        count_ref_reg_weights_isinf <- sum(is.infinite(ref_w))
        count_ref_reg_weights_nonpositive <- sum(ref_w <= 0, na.rm = TRUE)
        keep_rows <- !is.na(ref_w) & !is.infinite(ref_w) & (ref_w > 0)
        ES_data <- ES_data[keep_rows]
        rm(ref_w, keep_rows, first_non_missing)
        gc()
        count_dropped <- (count_ES_initial - nrow(ES_data))
        rm(count_ES_initial)
        if (count_dropped != 0) {
            flog.info((sprintf(
                "\n Warning: Droppped %s from stacked data due to missing or extreme ref_reg_weights. \n Of those dropped, the breakdown is: \n 1) %s%% had NA ref_reg_weights \n 2) %s%% had Inf/-Inf ref_reg_weights \n 3) %s%% had ref_reg_weights <= 0",
                format(count_dropped, scientific = FALSE, big.mark = ","),
                round(((count_ref_reg_weights_isna / count_dropped) *
                    100), digits = 4), round(((count_ref_reg_weights_isinf / count_dropped) *
                    100), digits = 4), round(((count_ref_reg_weights_nonpositive / count_dropped) *
                    100), digits = 4)
            )))
        }
        rm(
            count_ref_reg_weights_isna, count_ref_reg_weights_isinf,
            count_ref_reg_weights_nonpositive, count_dropped
        )
        gc()
    }
    # Inverse-propensity weighting. Each DiD (ref_onset_time x
    # catt_specific_sample) gets an integer id `did_id`. For each id,
    # ES_make_ipw_dt() supplies a propensity score `pr`, which is merged back
    # by unit, calendar time and did_id. Rows with missing or infinite scores,
    # or scores outside the open interval
    # (ipw_ps_lower_bound, ipw_ps_upper_bound), are dropped. Treated rows get
    # pw = 1 and control rows get pw = pr / (1 - pr).
    if (ipw == TRUE) {
        ES_data[, `:=`(sortorder, .I)]
        ES_data[, `:=`(did_id, .GRP), by = list(
            ref_onset_time,
            catt_specific_sample
        )]
        setorderv(ES_data, "sortorder")
        ES_data[, `:=`(sortorder, NULL)]
        # The propensity-score loop iterates over these ids. Previously
        # `did_ids` was used below without ever being defined.
        did_ids <- ES_data[, sort(unique(did_id))]
        # NOTE(review): nothing in this block computes `pr`/`pw` when
        # ipw_composition_change is TRUE; confirm that path is handled elsewhere.
        if (ipw_composition_change == FALSE) {
            ES_data[, `:=`(pr_temp, as.numeric(NA))]
            for (i in did_ids) {
                ipw_dt <- ES_make_ipw_dt(
                    did_dt = copy(ES_data[did_id ==
                        i]), unit_var = unit_var, cal_time_var = cal_time_var,
                    discrete_covars = discrete_covars, cont_covars = cont_covars,
                    ref_discrete_covars = ref_discrete_covars,
                    ref_cont_covars = ref_cont_covars, omitted_event_time = omitted_event_time,
                    ipw_model = ipw_model, reg_weights = reg_weights,
                    ref_reg_weights = ref_reg_weights, ipw_composition_change = ipw_composition_change
                )
                ipw_dt[, `:=`(did_id, i)]
                ES_data <- merge(ES_data, ipw_dt, by = c(
                    unit_var,
                    cal_time_var, "did_id"
                ), all.x = TRUE, sort = FALSE)
                # Fill only rows not yet scored, so earlier DiDs are kept.
                ES_data[is.na(pr_temp) & !is.na(pr), `:=`(
                    pr_temp,
                    pr
                )]
                ES_data[, `:=`(pr, NULL)]
                ipw_dt <- NULL
                gc()
            }
            setnames(ES_data, "pr_temp", "pr")
            # Count drop reasons without materialising row subsets.
            count_ES_initial <- nrow(ES_data)
            count_pr_isna <- ES_data[, sum(is.na(pr))]
            count_pr_isinf <- ES_data[, sum(is.infinite(pr))]
            count_pr_extremelow <- ES_data[, sum(pr <= ipw_ps_lower_bound, na.rm = TRUE)]
            count_pr_extremehigh <- ES_data[, sum(pr >= ipw_ps_upper_bound, na.rm = TRUE)]
            # Strict bounds, equivalent to between(..., incbounds = FALSE).
            ES_data <- ES_data[!is.na(pr) & !is.infinite(pr) &
                (pr > ipw_ps_lower_bound) & (pr < ipw_ps_upper_bound)]
            gc()
            count_dropped <- (count_ES_initial - nrow(ES_data))
            rm(count_ES_initial)
            if (count_dropped != 0) {
                flog.info((sprintf(
                    "\n Warning: Droppped %s from stacked data due to missing or extreme estimated propensity scores. \n Of those dropped, the breakdown is: \n 1) %s%% had a NA propensity score \n 2) %s%% had a Inf/-Inf propensity score \n 3) %s%% had a propensity score <= %s \n 4) %s%% had a propensity score >= %s",
                    format(count_dropped,
                        scientific = FALSE,
                        big.mark = ","
                    ), round(((count_pr_isna / count_dropped) *
                        100), digits = 4), round(((count_pr_isinf / count_dropped) *
                        100), digits = 4), round(((count_pr_extremelow / count_dropped) *
                        100), digits = 4), ipw_ps_lower_bound, round(((count_pr_extremehigh / count_dropped) *
                        100), digits = 4), ipw_ps_upper_bound
                )))
            }
            rm(
                count_pr_isna, count_pr_isinf, count_pr_extremelow,
                count_pr_extremehigh, count_dropped
            )
            ES_data[treated == 1, `:=`(pw, 1)]
            ES_data[treated == 0, `:=`(pw, (pr / (1 - pr)))]
            if (ipw_keep_data == TRUE) {
                ipw_dt <- ES_data[, list(
                    get(unit_var), ref_onset_time,
                    ref_event_time, catt_specific_sample, treated,
                    pr
                )]
                setnames(ipw_dt, "V1", unit_var)
            }
        }
    }

    # Early exit: when return_clean_data is TRUE, hand back the fully prepared
    # stacked data (after the covariate, weight and IPW steps above) without
    # running any ATT estimation.
    if (return_clean_data) {
        flog.info("Since you specified return_clean_data=TRUE, stopping estimation and returning the cleaned data.table object that would have been used directly in ATT estimation.")
        return(ES_data)
    }

    # Cohort-specific (heterogeneous) ATTs. The result of
    # by_cohort_ES_estimate_ATT() is unpacked as [[1]] results table,
    # [[2]] CATT coefficients and [[3]] their vcov. The reduced form ("rf")
    # uses `outcomevar`; the first stage ("fs") uses `endog_var` when supplied.
    if (homogeneous_ATT == FALSE) {
        run_hetero_by_cohort <- function(dep_var) {
            by_cohort_ES_estimate_ATT(
                ES_data = copy(ES_data),
                outcomevar = dep_var, unit_var = unit_var,
                onset_time_var = onset_time_var, cluster_vars = cluster_vars,
                homogeneous_ATT = homogeneous_ATT,
                omitted_event_time = omitted_event_time,
                discrete_covars = discrete_covars, cont_covars = cont_covars,
                ref_discrete_covars = ref_discrete_covars,
                ref_cont_covars = ref_cont_covars,
                residualize_covariates = residualize_covariates,
                reg_weights = reg_weights, ref_reg_weights = ref_reg_weights,
                ipw = ipw, ipw_composition_change = ipw_composition_change,
                add_unit_fes = add_unit_fes, cohort_by_cohort = cohort_by_cohort,
                cohort_by_cohort_num_cores = cohort_by_cohort_num_cores
            )
        }

        hetero_out <- run_hetero_by_cohort(outcomevar)
        gc()
        catt_coefs_rf <- hetero_out[[2]]
        catt_vcov_rf <- hetero_out[[3]]
        ES_results_hetero_rf <- hetero_out[[1]]
        hetero_out <- NULL
        gc()
        setnames(
            ES_results_hetero_rf,
            c(onset_time_var, "event_time"),
            c("ref_onset_time", "ref_event_time")
        )

        ES_results_hetero_fs <- NULL
        if (!is.null(endog_var)) {
            hetero_out <- run_hetero_by_cohort(endog_var)
            gc()
            catt_coefs_fs <- hetero_out[[2]]
            catt_vcov_fs <- hetero_out[[3]]
            ES_results_hetero_fs <- hetero_out[[1]]
            hetero_out <- NULL
            gc()
            setnames(
                ES_results_hetero_fs,
                c(onset_time_var, "event_time"),
                c("ref_onset_time", "ref_event_time")
            )
        }
        rm(run_hetero_by_cohort, hetero_out)
    } else {
        ES_results_hetero_rf <- NULL
        ES_results_hetero_fs <- NULL
    }
    # Pooled (homogeneous) ATTs, estimated unless heterogeneous_only is TRUE.
    # Only the results table ([[1]]) of by_cohort_ES_estimate_ATT() is kept,
    # with its onset/event-time columns renamed to the ref_* convention.
    if (heterogeneous_only == FALSE) {
        run_pooled <- function(dep_var) {
            pooled <- by_cohort_ES_estimate_ATT(
                ES_data = copy(ES_data),
                outcomevar = dep_var, unit_var = unit_var,
                onset_time_var = onset_time_var, cluster_vars = cluster_vars,
                homogeneous_ATT = TRUE,
                omitted_event_time = omitted_event_time,
                discrete_covars = discrete_covars, cont_covars = cont_covars,
                ref_discrete_covars = ref_discrete_covars,
                ref_cont_covars = ref_cont_covars,
                residualize_covariates = residualize_covariates,
                reg_weights = reg_weights, ref_reg_weights = ref_reg_weights,
                ipw = ipw, ipw_composition_change = ipw_composition_change,
                add_unit_fes = add_unit_fes, cohort_by_cohort = cohort_by_cohort,
                cohort_by_cohort_num_cores = cohort_by_cohort_num_cores
            )[[1]]
            gc()
            setnames(
                pooled,
                c(onset_time_var, "event_time"),
                c("ref_onset_time", "ref_event_time")
            )
            pooled
        }

        ES_results_homo_rf <- run_pooled(outcomevar)
        if (is.null(endog_var)) {
            ES_results_homo_fs <- NULL
        } else {
            ES_results_homo_fs <- run_pooled(endog_var)
        }
        rm(run_pooled)
    } else {
        ES_results_homo_rf <- NULL
        ES_results_homo_fs <- NULL
    }
    # Raw means of the outcome (and of endog_var, if given) by cohort x event
    # time x treatment status, labelled rn = "treatment_means". When regression
    # weights are supplied, the mean is weighted by the weight column named in
    # reg_weights/ref_reg_weights. NOTE(review): if both are given and differ,
    # get() receives two names; confirm that is intended. `cluster_se` is the
    # unweighted sd / sqrt(.N) in both cases.
    weights_supplied <- !(is.null(reg_weights)) | !(is.null(ref_reg_weights))
    tabulate_group_means <- function(mean_target_col) {
        if (weights_supplied) {
            mean_weight_col <- na.omit(unique(c(reg_weights, ref_reg_weights)))
            group_means <- ES_data[
                , list(
                    rn = "treatment_means",
                    estimate = weighted.mean(
                        get(mean_target_col),
                        get(mean_weight_col)
                    ),
                    cluster_se = sd(get(mean_target_col)) / sqrt(.N)
                ),
                list(ref_onset_time, ref_event_time, treated)
            ]
        } else {
            group_means <- ES_data[
                , list(
                    rn = "treatment_means",
                    estimate = mean(get(mean_target_col)),
                    cluster_se = sd(get(mean_target_col)) / sqrt(.N)
                ),
                list(ref_onset_time, ref_event_time, treated)
            ]
        }
        group_means[order(ref_onset_time, ref_event_time, treated)]
    }

    ES_treatcontrol_means_rf <- tabulate_group_means(outcomevar)
    if (is.null(endog_var)) {
        ES_treatcontrol_means_fs <- NULL
    } else {
        ES_treatcontrol_means_fs <- tabulate_group_means(endog_var)
    }
    rm(weights_supplied, tabulate_group_means)
    # Baseline (omitted event time) outcome means by cohort x
    # catt_specific_sample, for all rows (means_rf) and for treated rows only
    # (treat_means_rf), each with its pooled counterpart. Both are then
    # expanded to every non-omitted event time via `mapping`, which holds row
    # counts .N per cohort x catt_specific_sample x event time.
    means_rf <- ES_data[
        ref_event_time == omitted_event_time,
        list(mean_outcome = mean(get(outcomevar))),
        by = list(ref_onset_time, catt_specific_sample)
    ][order(ref_onset_time, catt_specific_sample)]
    pooled_baseline <- ES_data[
        ref_event_time == omitted_event_time,
        mean(get(outcomevar))
    ]
    means_rf[, `:=`(pooled_mean, pooled_baseline)]

    treat_means_rf <- ES_data[
        ref_event_time == omitted_event_time & treated == 1,
        list(treat_mean_outcome = mean(get(outcomevar))),
        by = list(ref_onset_time, catt_specific_sample)
    ][order(ref_onset_time, catt_specific_sample)]
    pooled_baseline <- ES_data[
        ref_event_time == omitted_event_time & treated == 1,
        mean(get(outcomevar))
    ]
    treat_means_rf[, `:=`(treat_pooled_mean, pooled_baseline)]
    rm(pooled_baseline)

    # Filtering out the omitted event time before grouping yields the same
    # groups and counts as grouping everything and filtering afterwards.
    mapping <- ES_data[
        ref_event_time != omitted_event_time,
        .N,
        by = list(ref_onset_time, catt_specific_sample, ref_event_time)
    ][order(ref_onset_time, catt_specific_sample, ref_event_time)]

    baseline_keys <- c("ref_onset_time", "catt_specific_sample")
    means_rf <- merge(means_rf, mapping, by = baseline_keys, all.x = TRUE)
    treat_means_rf <- merge(treat_means_rf, mapping, by = baseline_keys, all.x = TRUE)
    rm(baseline_keys)
    # Optional means of `ref_event_time_mean_vars` at three reference event
    # times (ref_event_time_mean_var_event_time1/2/3), by cohort x
    # catt_specific_sample. Control rows (treated == 0) produce "mean_"-prefixed
    # columns and treated rows (treated == 1) produce "treat_mean_"-prefixed
    # columns. Each table is expanded to every non-omitted event time via
    # `mapping`.
    if (!is.null(ref_event_time_mean_vars)) {
        summarize_at_event_time <- function(target_event_time, target_treated,
                                            col_prefix) {
            group_keys <- c("ref_onset_time", "catt_specific_sample")
            period_means <- ES_data[
                ref_event_time == target_event_time & treated == target_treated,
                lapply(.SD, mean, na.rm = TRUE),
                by = list(ref_onset_time, catt_specific_sample),
                .SDcols = ref_event_time_mean_vars
            ][order(ref_onset_time, catt_specific_sample)]
            value_cols <- setdiff(colnames(period_means), group_keys)
            setnames(period_means, value_cols, paste0(col_prefix, value_cols))
            merge(period_means, mapping, by = group_keys, all.x = TRUE)
        }

        req_varmeans1 <- summarize_at_event_time(
            ref_event_time_mean_var_event_time1, 0, "mean_"
        )
        req_treat_varmeans1 <- summarize_at_event_time(
            ref_event_time_mean_var_event_time1, 1, "treat_mean_"
        )

        req_varmeans2 <- summarize_at_event_time(
            ref_event_time_mean_var_event_time2, 0, "mean_"
        )
        req_treat_varmeans2 <- summarize_at_event_time(
            ref_event_time_mean_var_event_time2, 1, "treat_mean_"
        )

        req_varmeans3 <- summarize_at_event_time(
            ref_event_time_mean_var_event_time3, 0, "mean_"
        )
        req_treat_varmeans3 <- summarize_at_event_time(
            ref_event_time_mean_var_event_time3, 1, "treat_mean_"
        )

        rm(summarize_at_event_time)
    }
    # First-stage counterpart of the baseline means: the same tables as
    # means_rf / treat_means_rf, but computed for `endog_var`. Both are NULL
    # when no endogenous variable is supplied.
    if (is.null(endog_var)) {
        means_fs <- NULL
        treat_means_fs <- NULL
    } else {
        means_fs <- ES_data[
            ref_event_time == omitted_event_time,
            list(mean_outcome = mean(get(endog_var))),
            by = list(ref_onset_time, catt_specific_sample)
        ][order(ref_onset_time, catt_specific_sample)]
        fs_pooled <- ES_data[
            ref_event_time == omitted_event_time,
            mean(get(endog_var))
        ]
        means_fs[, `:=`(pooled_mean, fs_pooled)]

        treat_means_fs <- ES_data[
            ref_event_time == omitted_event_time & treated == 1,
            list(treat_mean_outcome = mean(get(endog_var))),
            by = list(ref_onset_time, catt_specific_sample)
        ][order(ref_onset_time, catt_specific_sample)]
        fs_pooled <- ES_data[
            ref_event_time == omitted_event_time & treated == 1,
            mean(get(endog_var))
        ]
        treat_means_fs[, `:=`(treat_pooled_mean, fs_pooled)]
        rm(fs_pooled)

        fs_keys <- c("ref_onset_time", "catt_specific_sample")
        means_fs <- merge(means_fs, mapping, by = fs_keys, all.x = TRUE)
        treat_means_fs <- merge(treat_means_fs, mapping, by = fs_keys, all.x = TRUE)
        rm(fs_keys)
    }

    # Treated-row and all-row totals per cohort x event time, with the omitted
    # event time excluded. When regression weights are supplied these are sums
    # of the weight column; otherwise they are row counts (.N). They feed the
    # cohort weights used for the averaged ATTs further below.
    if (!(is.null(reg_weights)) | !(is.null(ref_reg_weights))) {
        unit_weight_col <- na.omit(unique(c(reg_weights, ref_reg_weights)))
        catt_treated_unique_units <- ES_data[
            treated == 1 & (ref_event_time != omitted_event_time),
            list(catt_treated_unique_units = sum(get(unit_weight_col))),
            by = list(ref_onset_time, ref_event_time)
        ][order(ref_onset_time, ref_event_time)]
        catt_total_unique_units <- ES_data[
            ref_event_time != omitted_event_time,
            list(catt_total_unique_units = sum(get(unit_weight_col))),
            by = list(ref_onset_time, ref_event_time)
        ][order(ref_onset_time, ref_event_time)]
        rm(unit_weight_col)
    } else {
        catt_treated_unique_units <- ES_data[
            treated == 1 & (ref_event_time != omitted_event_time),
            list(catt_treated_unique_units = .N),
            by = list(ref_onset_time, ref_event_time)
        ][order(ref_onset_time, ref_event_time)]
        catt_total_unique_units <- ES_data[
            ref_event_time != omitted_event_time,
            list(catt_total_unique_units = .N),
            by = list(ref_onset_time, ref_event_time)
        ][order(ref_onset_time, ref_event_time)]
    }

    # Distinct units (NAs excluded) contributing to each non-omitted event
    # time. When collapsed estimates are requested for heterogeneous ATTs, the
    # same count is computed for each named group of event times in
    # `collapse_inputs`, a two-column table renamed to (name, event_times).
    event_time_total_unique_units <- ES_data[
        ref_event_time != omitted_event_time,
        list(event_time_total_unique_units = uniqueN(get(unit_var), na.rm = TRUE)),
        by = list(ref_event_time)
    ][order(ref_event_time)]

    if (calculate_collapse_estimates == TRUE & homogeneous_ATT == FALSE) {
        collapse_input_dt <- copy(collapse_inputs)
        setnames(collapse_input_dt, c("name", "event_times"))
        group_names <- unique(na.omit(collapse_input_dt[["name"]]))
        unit_counts <- vector("list", length(group_names))
        k <- 0
        for (group_label in group_names) {
            k <- k + 1
            group_event_times <- unique(na.omit(unlist(
                collapse_input_dt[name == group_label][[2]]
            )))
            unit_counts[[k]] <- data.table(
                grouping = group_label,
                collapsed_estimate_total_unique_units = ES_data[
                    (ref_event_time != omitted_event_time) &
                        (ref_event_time %in% group_event_times),
                    uniqueN(get(unit_var), na.rm = TRUE)
                ]
            )
        }
        collapsed_estimate_total_unique_units <- rbindlist(unit_counts,
            use.names = TRUE
        )
        rm(group_names, unit_counts, k)
    }
    # Number of distinct units (NAs excluded) in the final estimation sample.
    unique_units <- ES_data[, uniqueN(get(unit_var), na.rm = TRUE)]
    gc()
    # Release the stacked data before building figdata.
    ES_data <- NULL
    gc()
    # Assemble figdata. The reduced-form ("rf") and, when endog_var is given,
    # first-stage ("fs") pieces (cohort-specific ATTs, pooled ATTs and raw
    # treatment/control means) are stacked into one table. Treated and total
    # counts are then attached by cohort x event time.
    figdata_rf <- rbindlist(
        list(ES_results_hetero_rf, ES_results_homo_rf, ES_treatcontrol_means_rf),
        use.names = TRUE, fill = TRUE
    )
    rm(ES_results_hetero_rf, ES_results_homo_rf, ES_treatcontrol_means_rf)
    figdata_rf[, `:=`(model, "rf")]

    figdata_fs <- NULL
    if (!is.null(endog_var)) {
        figdata_fs <- rbindlist(
            list(ES_results_hetero_fs, ES_results_homo_fs, ES_treatcontrol_means_fs),
            use.names = TRUE, fill = TRUE
        )
        rm(ES_results_hetero_fs, ES_results_homo_fs, ES_treatcontrol_means_fs)
        figdata_fs[, `:=`(model, "fs")]
    }

    figdata <- rbindlist(list(figdata_rf, figdata_fs), use.names = TRUE, fill = TRUE)
    rm(figdata_rf, figdata_fs)

    # Onset times become character so that numeric cohorts can share the
    # column with labels such as "Equally-Weighted" added later.
    figdata[, `:=`(ref_onset_time, as.character(ref_onset_time))]
    catt_treated_unique_units[, `:=`(ref_onset_time, as.character(ref_onset_time))]
    catt_total_unique_units[, `:=`(ref_onset_time, as.character(ref_onset_time))]

    cohort_event_keys <- c("ref_onset_time", "ref_event_time")
    figdata <- merge(figdata, catt_treated_unique_units,
        by = cohort_event_keys, all.x = TRUE, sort = FALSE
    )
    figdata <- merge(figdata, catt_total_unique_units,
        by = cohort_event_keys, all.x = TRUE, sort = FALSE
    )
    rm(cohort_event_keys)
    # Averaged ATTs across cohorts. From the cohort-specific rows
    # (rn == "catt"), compute per event time x model:
    #   - an equally-weighted mean of the estimates,
    #   - a mean weighted by treated counts/weight sums (cohort_weight_V1),
    #   - a mean weighted by total counts/weight sums (cohort_weight_V2).
    # These are appended to figdata as rn == "att" rows whose ref_onset_time
    # holds the averaging label.
    subsets_for_avgs <- figdata[rn %in% c("catt")]
    subsets_for_avgs[, `:=`(unweighted_estimate, mean(estimate,
        na.rm = TRUE
    )), by = list(ref_event_time, model)]
    # Each cohort's share of the event-time total (treated / all rows).
    subsets_for_avgs[, `:=`(cohort_weight_V1, catt_treated_unique_units / sum(catt_treated_unique_units,
        na.rm = TRUE
    )), by = list(ref_event_time, model)]
    subsets_for_avgs[, `:=`(cohort_weight_V2, catt_total_unique_units / sum(catt_total_unique_units,
        na.rm = TRUE
    )), by = list(ref_event_time, model)]
    subsets_for_avgs[, `:=`(weighted_estimate_V1, weighted.mean(
        x = estimate,
        w = cohort_weight_V1, na.rm = TRUE
    )), by = list(
        ref_event_time,
        model
    )]
    subsets_for_avgs[, `:=`(weighted_estimate_V2, weighted.mean(
        x = estimate,
        w = cohort_weight_V2, na.rm = TRUE
    )), by = list(
        ref_event_time,
        model
    )]
    # Attach the cohort weights to figdata. The merge keys do not include
    # `rn`, so every figdata row with a matching onset/event time/model (not
    # only the "catt" rows) receives them.
    weights <- subsets_for_avgs[, list(
        ref_event_time, ref_onset_time,
        model, cohort_weight_V1, cohort_weight_V2
    )]
    figdata <- merge(figdata, weights, by = c(
        "ref_onset_time",
        "ref_event_time", "model"
    ), all.x = TRUE, sort = FALSE)
    # Keep one row per event time x model; the averages are constant within
    # each group. (seq_len() never yields NA, so is.na(rowid) never matches.)
    subsets_for_avgs[, `:=`(rowid, seq_len(.N)), by = list(
        ref_event_time,
        model
    )]
    subsets_for_avgs <- subsets_for_avgs[rowid == 1 | is.na(rowid)]
    unweighted <- subsets_for_avgs[, list(
        ref_event_time, unweighted_estimate,
        model
    )]
    unweighted[, `:=`(ref_onset_time, "Equally-Weighted")]
    unweighted[, `:=`(rn, "att")]
    setnames(unweighted, c("unweighted_estimate"), c("estimate"))
    setorderv(unweighted, c("ref_event_time", "model"))
    weighted_V1 <- subsets_for_avgs[, list(
        ref_event_time, weighted_estimate_V1,
        model
    )]
    weighted_V1[, `:=`(ref_onset_time, "Cohort-Weighted")]
    weighted_V1[, `:=`(rn, "att")]
    setnames(weighted_V1, c("weighted_estimate_V1"), c("estimate"))
    setorderv(weighted_V1, c("ref_event_time", "model"))
    weighted_V2 <- subsets_for_avgs[, list(
        ref_event_time, weighted_estimate_V2,
        model
    )]
    weighted_V2[, `:=`(ref_onset_time, "Cohort-Weighted V2")]
    weighted_V2[, `:=`(rn, "att")]
    setnames(weighted_V2, c("weighted_estimate_V2"), c("estimate"))
    setorderv(weighted_V2, c("ref_event_time", "model"))
    figdata <- rbindlist(list(
        figdata, unweighted, weighted_V1,
        weighted_V2
    ), use.names = TRUE, fill = TRUE)
    # Missing SEs (including the new averaged rows) are set to 0. The
    # delta-method step that follows fills "att" rows whose cluster_se == 0.
    figdata[is.na(cluster_se), `:=`(cluster_se, 0)]
    rm(subsets_for_avgs)
    rm(weights)
    rm(unweighted)
    rm(weighted_V1)
    rm(weighted_V2)
    rm(catt_treated_unique_units)
    rm(catt_total_unique_units)
    gc()

    if (homogeneous_ATT == FALSE) {
        event_times <- setdiff(
            figdata[, sort(unique(ref_event_time))],
            omitted_event_time
        )
        ## Delta-method SEs for the event-time aggregates ("Equally-Weighted",
        ## "Cohort-Weighted", "Cohort-Weighted V2") of each model.
        ## Each aggregate is a weighted sum of the cohort-by-event-time
        ## coefficients in catt_coefs_<model>. The weights are written into a
        ## formula string "(w1*x<i1>)+(w2*x<i2>)+..." and passed to
        ## delta_method() together with catt_vcov_<model>.
        ## (presumably x<i> denotes the i-th element of `mean`; delta_method is
        ## defined elsewhere -- confirm.)
        ## Only figdata rows whose cluster_se is still 0 are overwritten.
        onset_times <- as.integer(figdata[rn == "catt", sort(unique(ref_onset_time))])
        # min_onset_time / max_onset_time are also read later in this function
        min_onset_time <- min(onset_times)
        max_onset_time <- max(onset_times)
        if (!is.null(endog_var)) {
            model_vals <- c("rf", "fs")
        } else {
            model_vals <- "rf"
        }
        for (mm in model_vals) {
            # Look the model's coefficient vector and vcov matrix up once per model.
            mm_coefs <- get(sprintf("catt_coefs_%s", mm))
            mm_vcov <- get(sprintf("catt_vcov_%s", mm))
            for (et in event_times) {
                # The "$" anchor keeps, e.g., event time 1 from also matching
                # 10:19 (and -1 from matching -19:-10).
                if (et < 0) {
                    lookfor <- sprintf("cattlead%s$", abs(et))
                } else {
                    lookfor <- sprintf("catt%s$", abs(et))
                }
                coef_indices <- grep(lookfor, names(mm_coefs))
                rm(lookfor)
                # One row per matching coefficient: rn = its name, coef_index = its position.
                temp <- as.data.table(
                    do.call(cbind, list(mm_coefs[coef_indices], coef_indices)),
                    keep.rownames = TRUE
                )
                setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
                rm(coef_indices)
                temp[, estimate := NULL]
                temp[, rn := gsub("lead", "-", rn)]
                # Split "ref_onset_time<cohort>_catt<et>" into the cohort and "catt<et>".
                # Cohorts are visited in increasing order, so e.g. cohort 10 overrides
                # an earlier prefix match by cohort 1.
                for (cohort in min_onset_time:max_onset_time) {
                    cohort_pattern <- sprintf("ref\\_onset\\_time%s", cohort)
                    temp[grepl(cohort_pattern, rn), ref_onset_time := cohort]
                    temp[
                        grepl(cohort_pattern, rn),
                        rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", cohort), "catt", rn)
                    ]
                }
                rm(cohort, cohort_pattern)
                temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
                temp[, rn := NULL]
                temp[, ref_onset_time := as.character(ref_onset_time)]
                # Attach the cohort weights stored on the matching "catt" rows of figdata.
                temp <- merge(temp, figdata[rn == "catt" & model == mm],
                    by = c("ref_onset_time", "ref_event_time"),
                    all.x = TRUE, sort = FALSE
                )
                temp <- temp[, list(
                    ref_onset_time, ref_event_time,
                    coef_index, cohort_weight_V1, cohort_weight_V2
                )]
                temp[, equal_weight := 1 / .N]
                temp[, equal_w_formula_entry := sprintf("(%s*x%s)", equal_weight, coef_index)]
                temp[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", cohort_weight_V1, coef_index)]
                temp[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", cohort_weight_V2, coef_index)]
                equal_w_g_formula_input <- paste0(temp$equal_w_formula_entry, collapse = "+")
                cohort_w_v1_g_formula_input <- paste0(temp$cohort_w_v1_formula_entry, collapse = "+")
                cohort_w_v2_g_formula_input <- paste0(temp$cohort_w_v2_formula_entry, collapse = "+")
                figdata[
                    rn == "att" & model == mm & cluster_se == 0 &
                        ref_event_time == et & ref_onset_time == "Equally-Weighted",
                    cluster_se := delta_method(
                        g = as.formula(paste("~", equal_w_g_formula_input)),
                        mean = mm_coefs,
                        cov = mm_vcov,
                        ses = TRUE
                    )
                ]
                figdata[
                    rn == "att" & model == mm & cluster_se == 0 &
                        ref_event_time == et & ref_onset_time == "Cohort-Weighted",
                    cluster_se := delta_method(
                        g = as.formula(paste("~", cohort_w_v1_g_formula_input)),
                        mean = mm_coefs,
                        cov = mm_vcov,
                        ses = TRUE
                    )
                ]
                figdata[
                    rn == "att" & model == mm & cluster_se == 0 &
                        ref_event_time == et & ref_onset_time == "Cohort-Weighted V2",
                    cluster_se := delta_method(
                        g = as.formula(paste("~", cohort_w_v2_g_formula_input)),
                        mean = mm_coefs,
                        cov = mm_vcov,
                        ses = TRUE
                    )
                ]
                rm(
                    temp, equal_w_g_formula_input, cohort_w_v1_g_formula_input,
                    cohort_w_v2_g_formula_input
                )
            }
            rm(mm_coefs, mm_vcov)
        }
        rm(et, mm)
        gc()
    }
    # Left-join the per-event-time unique-unit counts onto figdata (keeping every
    # figdata row and its current order), then drop the lookup table.
    figdata <- merge(
        x = figdata,
        y = event_time_total_unique_units,
        by = "ref_event_time",
        all.x = TRUE,
        sort = FALSE
    )
    rm(event_time_total_unique_units)

    ## Collapsed estimates: for every named grouping of event times in
    ## collapse_input_dt, average the event-time estimates (equally weighted,
    ## cohort-weighted V1, cohort-weighted V2) over the grouping's event times,
    ## compute the matching delta-method SEs, and append the results to figdata
    ## with ref_onset_time "... + Collapsed" and rn "att".
    if (calculate_collapse_estimates == TRUE && homogeneous_ATT == FALSE) {
        ew <- list()
        cw <- list()
        cw2 <- list()
        j <- 0

        # Models whose coefficients/vcov (catt_coefs_<mm>, catt_vcov_<mm>) are
        # used for the SEs; this does not depend on the grouping.
        if (!is.null(endog_var)) {
            model_vals <- c("rf", "fs")
        } else {
            model_vals <- "rf"
        }

        for (g in unique(na.omit(collapse_input_dt[["name"]]))) {
            j <- j + 1

            # Event times in this grouping; the omitted event time has no estimate.
            group_event_times <- setdiff(
                unique(na.omit(unlist(collapse_input_dt[name == g][[2]]))),
                omitted_event_time
            )
            dt <- figdata[(ref_event_time %in% group_event_times) & (rn == "catt")]
            dt[, grouping := g]
            # Per-event-time averages of the cohort estimates, separately by model.
            dt[, unweighted_estimate := mean(estimate, na.rm = TRUE),
                by = list(ref_event_time, model)
            ]
            dt[, weighted_estimate_V1 := weighted.mean(
                x = estimate,
                w = cohort_weight_V1, na.rm = TRUE
            ), by = list(ref_event_time, model)]
            dt[, weighted_estimate_V2 := weighted.mean(
                x = estimate,
                w = cohort_weight_V2, na.rm = TRUE
            ), by = list(ref_event_time, model)]
            # Keep a single row per (ref_event_time, model).
            dt[, rowid := seq_len(.N), by = list(ref_event_time, model)]
            result <- dt[rowid == 1 | is.na(rowid)]
            # Columns needed later to attach weights to coefficients.
            dt <- dt[, list(
                ref_event_time, ref_onset_time,
                model, cohort_weight_V1, cohort_weight_V2, grouping
            )]

            # Equally weighted: mean over event times of the per-event-time means.
            ew[[j]] <- result[, list(grouping, model, unweighted_estimate)]
            ew[[j]][, unweighted_estimate := mean(unweighted_estimate, na.rm = TRUE),
                by = list(grouping, model)
            ]
            ew[[j]][, ref_onset_time := "Equally-Weighted + Collapsed"]
            ew[[j]][, rn := "att"]
            ew[[j]][, rowid := seq_len(.N), by = list(model)]
            setnames(ew[[j]], "unweighted_estimate", "estimate")
            ew[[j]] <- ew[[j]][rowid == 1 | is.na(rowid)]
            ew[[j]][, rowid := NULL]
            ew[[j]][, cluster_se := 0] # replaced by the delta-method SE below

            # Cohort-weighted (V1 weights).
            cw[[j]] <- result[, list(grouping, model, weighted_estimate_V1)]
            cw[[j]][, weighted_estimate_V1 := mean(weighted_estimate_V1, na.rm = TRUE),
                by = list(grouping, model)
            ]
            cw[[j]][, ref_onset_time := "Cohort-Weighted + Collapsed"]
            cw[[j]][, rn := "att"]
            cw[[j]][, rowid := seq_len(.N), by = list(model)]
            setnames(cw[[j]], "weighted_estimate_V1", "estimate")
            cw[[j]] <- cw[[j]][rowid == 1 | is.na(rowid)]
            cw[[j]][, rowid := NULL]
            cw[[j]][, cluster_se := 0]

            # Cohort-weighted (V2 weights).
            cw2[[j]] <- result[, list(grouping, model, weighted_estimate_V2)]
            cw2[[j]][, weighted_estimate_V2 := mean(weighted_estimate_V2, na.rm = TRUE),
                by = list(grouping, model)
            ]
            cw2[[j]][, ref_onset_time := "Cohort-Weighted V2 + Collapsed"]
            cw2[[j]][, rn := "att"]
            cw2[[j]][, rowid := seq_len(.N), by = list(model)]
            setnames(cw2[[j]], "weighted_estimate_V2", "estimate")
            cw2[[j]] <- cw2[[j]][rowid == 1 | is.na(rowid)]
            cw2[[j]][, rowid := NULL]
            cw2[[j]][, cluster_se := 0]
            rm(result)

            for (mm in model_vals) {
                # Coefficient vector and vcov for this model, looked up once.
                mm_coefs <- get(sprintf("catt_coefs_%s", mm))
                mm_vcov <- get(sprintf("catt_vcov_%s", mm))
                templist <- list()
                i <- 0
                for (et in group_event_times) {
                    i <- i + 1

                    # The "$" anchor keeps, e.g., event time 1 from also matching
                    # 10:19 (and -1 from matching -19:-10).
                    if (et < 0) {
                        lookfor <- sprintf("cattlead%s$", abs(et))
                    } else {
                        lookfor <- sprintf("catt%s$", abs(et))
                    }
                    coef_indices <- grep(lookfor, names(mm_coefs))
                    rm(lookfor)
                    temp <- as.data.table(
                        do.call(cbind, list(mm_coefs[coef_indices], coef_indices)),
                        keep.rownames = TRUE
                    )
                    setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
                    rm(coef_indices)
                    temp[, estimate := NULL]
                    temp[, rn := gsub("lead", "-", rn)]
                    # Cohorts are visited in increasing order, so e.g. cohort 10
                    # overrides an earlier prefix match by cohort 1.
                    for (cohort in min_onset_time:max_onset_time) {
                        cohort_pattern <- sprintf("ref\\_onset\\_time%s", cohort)
                        temp[grepl(cohort_pattern, rn), ref_onset_time := cohort]
                        temp[
                            grepl(cohort_pattern, rn),
                            rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", cohort), "catt", rn)
                        ]
                    }
                    rm(cohort, cohort_pattern)
                    temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
                    temp[, rn := NULL]
                    temp[, ref_onset_time := as.character(ref_onset_time)]
                    # Within-event-time (cohort) weights from this grouping's catt rows.
                    temp <- merge(temp, dt[model == mm],
                        by = c("ref_onset_time", "ref_event_time"),
                        all.x = TRUE, sort = FALSE
                    )
                    temp[, weight_V0 := 1 / .N]
                    templist[[i]] <- copy(temp)
                    rm(temp)
                    gc()
                }
                rm(i)
                templist <- rbindlist(templist, use.names = TRUE)
                # Full weight = within-event-time weight x across-event-time weight.
                templist[, across_weight := (1 / uniqueN(ref_event_time))]
                templist[, full_weight_V0 := weight_V0 * across_weight]
                templist[, full_weight_V1 := cohort_weight_V1 * across_weight]
                templist[, full_weight_V2 := cohort_weight_V2 * across_weight]
                templist[, equal_w_formula_entry := sprintf("(%s*x%s)", full_weight_V0, coef_index)]
                templist[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", full_weight_V1, coef_index)]
                templist[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", full_weight_V2, coef_index)]
                formula_input_ew <- paste0(templist$equal_w_formula_entry, collapse = "+")
                formula_input_cw <- paste0(templist$cohort_w_v1_formula_entry, collapse = "+")
                formula_input_cw2 <- paste0(templist$cohort_w_v2_formula_entry, collapse = "+")
                rm(templist)
                ew[[j]][grouping == g & model == mm, cluster_se := delta_method(
                    g = as.formula(paste("~", formula_input_ew)),
                    mean = mm_coefs,
                    cov = mm_vcov,
                    ses = TRUE
                )]
                cw[[j]][grouping == g & model == mm, cluster_se := delta_method(
                    g = as.formula(paste("~", formula_input_cw)),
                    mean = mm_coefs,
                    cov = mm_vcov,
                    ses = TRUE
                )]
                cw2[[j]][grouping == g & model == mm, cluster_se := delta_method(
                    g = as.formula(paste("~", formula_input_cw2)),
                    mean = mm_coefs,
                    cov = mm_vcov,
                    ses = TRUE
                )]
                rm(formula_input_ew, formula_input_cw, formula_input_cw2)
                rm(mm_coefs, mm_vcov)
            }
            rm(dt)
        }
        rm(j)
        ew <- rbindlist(ew, use.names = TRUE)
        cw <- rbindlist(cw, use.names = TRUE)
        cw2 <- rbindlist(cw2, use.names = TRUE)
        ew <- merge(ew, collapsed_estimate_total_unique_units,
            by = "grouping", all.x = TRUE, sort = FALSE
        )
        cw <- merge(cw, collapsed_estimate_total_unique_units,
            by = "grouping", all.x = TRUE, sort = FALSE
        )
        cw2 <- merge(cw2, collapsed_estimate_total_unique_units,
            by = "grouping", all.x = TRUE, sort = FALSE
        )
        figdata <- rbindlist(list(figdata, ew, cw, cw2),
            use.names = TRUE,
            fill = TRUE
        )
        rm(ew, cw, cw2)
        rm(collapsed_estimate_total_unique_units)
        gc()
    }
    # Copy unique_units into figdata's total_unique_units column, then drop the
    # standalone object.
    figdata[, total_unique_units := unique_units]
    rm(unique_units)

    if (calculate_collapse_estimates == FALSE) {
        # No collapsed estimates were requested: still create the grouping and
        # collapsed-count columns (all NA) so later code can reference them.
        figdata[, grouping := NA]
        figdata[, collapsed_estimate_total_unique_units := NA]
    }

    ## Ratio estimates (reduced form / first stage), computed when an endogenous
    ## variable is supplied and ATTs are heterogeneous. Produces:
    ##   * cohort-by-event-time ratios ("catt") with SEs from makeRatioVcov(),
    ##   * event-time aggregates ("att": means of the cohort ratios) with
    ##     delta-method SEs,
    ##   * optionally, collapsed aggregates over groupings of event times,
    ## and appends them all to figdata with model = "ratio".
    if (!is.null(endog_var) && homogeneous_ATT == FALSE) {

        ## Calculate cohort-specific vcov matrix for the ratio estimates
        # => then, calculate SEs for cohort-specific ratios
        catt_vcov_ratio <- makeRatioVcov(
            coefs_fs = copy(catt_coefs_fs),
            coefs_rf = copy(catt_coefs_rf),
            vcov_fs = copy(catt_vcov_fs),
            vcov_rf = copy(catt_vcov_rf)
        )
        catt_coefs_ratio <- catt_coefs_rf / catt_coefs_fs
        catt_ses_ratio <- sqrt(diag(catt_vcov_ratio))
        names(catt_ses_ratio) <- names(catt_coefs_rf)

        # as.data.table() names the value column after the variable, hence the setnames().
        catt_ses_ratio <- as.data.table(catt_ses_ratio, keep.rownames = TRUE)
        setnames(catt_ses_ratio, c("catt_ses_ratio"), c("cluster_se"))

        # Coefficients without a "ref_onset_time<cohort>" prefix are attributed to
        # min_onset_time; prefixed ones get their cohort from the name.
        catt_ses_ratio[!grepl("ref\\_onset\\_time", rn), e := min_onset_time]
        catt_ses_ratio[, rn := gsub("lead", "-", rn)]
        for (cohort in min_onset_time:max_onset_time) {
            cohort_pattern <- sprintf("ref\\_onset\\_time%s", cohort)
            catt_ses_ratio[grepl(cohort_pattern, rn), e := cohort]
            catt_ses_ratio[
                grepl(cohort_pattern, rn),
                rn := gsub(sprintf("ref\\_onset\\_time%s\\_et", cohort), "et", rn)
            ]
            catt_ses_ratio[
                grepl(cohort_pattern, rn),
                rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", cohort), "catt", rn)
            ]
        }
        rm(cohort, cohort_pattern)
        catt_ses_ratio[grepl("et", rn), event_time := as.integer(gsub("et", "", rn))]
        catt_ses_ratio[grepl("catt", rn), event_time := as.integer(gsub("catt", "", rn))]
        catt_ses_ratio[grepl("et", rn), rn := "event_time"]
        catt_ses_ratio[grepl("catt", rn), rn := "catt"]
        setnames(catt_ses_ratio, c("e", "event_time"), c("ref_onset_time", "ref_event_time"))
        # NOTE(review): once rn is dropped, any "event_time" rows share the merge key
        # (ref_onset_time, ref_event_time) with "catt" rows -- confirm none exist.
        catt_ses_ratio[, rn := NULL]
        catt_ses_ratio[, ref_onset_time := as.character(ref_onset_time)] # to match 'figdata'

        ## catts: estimates and SEs
        ratio_data_rf <- figdata[rn %in% c("catt") & model == "rf", list(
            ref_onset_time,
            ref_event_time,
            estimate,
            rn
        )]
        setnames(ratio_data_rf, "estimate", "estimate_rf")
        ratio_data_fs <- figdata[rn %in% c("catt") & model == "fs", list(
            ref_onset_time,
            ref_event_time,
            estimate,
            rn
        )]
        setnames(ratio_data_fs, "estimate", "estimate_fs")
        catt_means_of_ratio <- merge(ratio_data_rf, ratio_data_fs,
            by = c("ref_onset_time", "ref_event_time", "rn"),
            sort = FALSE
        )
        rm(ratio_data_rf, ratio_data_fs)
        # NOTE(review): a zero first-stage estimate yields Inf/NaN here.
        catt_means_of_ratio[, estimate := (estimate_rf / estimate_fs)]
        catt_means_of_ratio[, c("estimate_rf", "estimate_fs") := NULL]

        catt_means_of_ratio <- merge(catt_means_of_ratio, catt_ses_ratio,
            by = c("ref_onset_time", "ref_event_time"),
            sort = FALSE
        )

        # Counts and cohort weights are taken from the reduced-form catt rows.
        counts_and_weights <- figdata[rn == "catt" & model == "rf", list(
            ref_onset_time,
            ref_event_time, rn,
            reg_sample_size,
            catt_treated_unique_units,
            catt_total_unique_units,
            cohort_weight_V1,
            cohort_weight_V2
        )]

        catt_means_of_ratio <- merge(catt_means_of_ratio, counts_and_weights,
            by = c("ref_onset_time", "ref_event_time", "rn"),
            sort = FALSE
        )
        rm(counts_and_weights)

        ## event-time atts (e.g. Equal-Weighted, Cohort-Weighted, etc)  (mean of ratios estimates)
        ew <- catt_means_of_ratio[,
            list(
                ref_onset_time = "Equally-Weighted",
                estimate = mean(estimate)
            ),
            by = list(ref_event_time)
        ]
        cw <- catt_means_of_ratio[,
            list(
                ref_onset_time = "Cohort-Weighted",
                estimate = weighted.mean(estimate,
                    w = (catt_treated_unique_units / sum(catt_treated_unique_units))
                )
            ),
            by = list(ref_event_time)
        ]
        cw2 <- catt_means_of_ratio[,
            list(
                ref_onset_time = "Cohort-Weighted V2",
                estimate = weighted.mean(estimate,
                    w = (catt_total_unique_units / sum(catt_total_unique_units))
                )
            ),
            by = list(ref_event_time)
        ]

        event_time_means_of_ratio <- rbindlist(list(ew, cw, cw2), use.names = TRUE)
        rm(ew, cw, cw2)

        counts_and_weights <- figdata[
            rn == "att" & is.na(grouping) & model == "rf",
            list(ref_onset_time, ref_event_time, event_time_total_unique_units)
        ]
        event_time_means_of_ratio <- merge(event_time_means_of_ratio, counts_and_weights,
            by = c("ref_onset_time", "ref_event_time"), sort = FALSE
        )

        rm(counts_and_weights)
        event_time_means_of_ratio[, rn := "att"]

        ## event-time atts (e.g. Equal-Weighted, Cohort-Weighted, etc)  (mean of ratios SEs)
        for (et in event_times) {
            # The "$" anchor is crucial: otherwise, e.g., 1 would also find 10:19
            # (and -1 would find -19:-10).
            if (et < 0) {
                lookfor <- sprintf("cattlead%s$", abs(et))
            } else {
                lookfor <- sprintf("catt%s$", abs(et))
            }
            coef_indices <- grep(lookfor, names(catt_coefs_ratio))
            rm(lookfor)
            temp <- as.data.table(
                do.call(cbind, list(catt_coefs_ratio[coef_indices], coef_indices)),
                keep.rownames = TRUE
            )
            setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
            rm(coef_indices)
            temp[, estimate := NULL]
            temp[, rn := gsub("lead", "-", rn)]
            for (cohort in min_onset_time:max_onset_time) {
                cohort_pattern <- sprintf("ref\\_onset\\_time%s", cohort)
                temp[grepl(cohort_pattern, rn), ref_onset_time := cohort]
                temp[
                    grepl(cohort_pattern, rn),
                    rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", cohort), "catt", rn)
                ]
            }
            rm(cohort, cohort_pattern)
            temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
            temp[, rn := NULL]
            temp[, ref_onset_time := as.character(ref_onset_time)]

            # now merge in the weights
            temp <- merge(temp, figdata[rn == "catt" & model == "rf"],
                by = c("ref_onset_time", "ref_event_time"),
                all.x = TRUE, sort = FALSE
            )
            temp <- temp[, list(ref_onset_time, ref_event_time, coef_index, catt_treated_unique_units, catt_total_unique_units)]
            temp[, equal_weight := 1 / .N]
            temp[, cohort_weight_V1 := catt_treated_unique_units / sum(catt_treated_unique_units)]
            temp[, cohort_weight_V2 := catt_total_unique_units / sum(catt_total_unique_units)]

            temp[, equal_w_formula_entry := sprintf("(%s*x%s)", equal_weight, coef_index)]
            temp[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", cohort_weight_V1, coef_index)]
            temp[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", cohort_weight_V2, coef_index)]

            equal_w_g_formula_input <- paste0(temp$equal_w_formula_entry, collapse = "+")
            cohort_w_v1_g_formula_input <- paste0(temp$cohort_w_v1_formula_entry, collapse = "+")
            cohort_w_v2_g_formula_input <- paste0(temp$cohort_w_v2_formula_entry, collapse = "+")

            event_time_means_of_ratio[
                rn == "att" & ref_event_time == et & ref_onset_time == "Equally-Weighted",
                cluster_se := delta_method(
                    g = as.formula(paste("~", equal_w_g_formula_input)),
                    mean = catt_coefs_ratio,
                    cov = catt_vcov_ratio,
                    ses = TRUE
                )
            ]

            event_time_means_of_ratio[
                rn == "att" & ref_event_time == et & ref_onset_time == "Cohort-Weighted",
                cluster_se := delta_method(
                    g = as.formula(paste("~", cohort_w_v1_g_formula_input)),
                    mean = catt_coefs_ratio,
                    cov = catt_vcov_ratio,
                    ses = TRUE
                )
            ]

            event_time_means_of_ratio[
                rn == "att" & ref_event_time == et & ref_onset_time == "Cohort-Weighted V2",
                cluster_se := delta_method(
                    g = as.formula(paste("~", cohort_w_v2_g_formula_input)),
                    mean = catt_coefs_ratio,
                    cov = catt_vcov_ratio,
                    ses = TRUE
                )
            ]

            rm(temp, equal_w_g_formula_input, cohort_w_v1_g_formula_input, cohort_w_v2_g_formula_input)
        }
        rm(et)
        gc()

        ## collapsed event-time atts
        # Defined up front so the final rbindlist()/rm() below work even when
        # collapsed estimates are not requested (rbindlist skips NULL items).
        collapsed_means_of_ratio <- NULL
        if (calculate_collapse_estimates == TRUE) {
            ew <- list()
            cw <- list()
            cw2 <- list()
            j <- 0

            # estimates
            for (g in unique(na.omit(collapse_input_dt[["name"]]))) {
                j <- j + 1
                group_event_times <- setdiff(
                    unique(na.omit(unlist(collapse_input_dt[name == g][[2]]))),
                    omitted_event_time
                )
                ddt <- catt_means_of_ratio[(ref_event_time %in% group_event_times)]

                ddt[, grouping := g]
                ddt[, unweighted_estimate := mean(estimate), by = list(ref_event_time)]
                ddt[, weighted_estimate_V1 := weighted.mean(estimate,
                    w = (catt_treated_unique_units / sum(catt_treated_unique_units))
                ), by = list(ref_event_time)]
                ddt[, weighted_estimate_V2 := weighted.mean(estimate,
                    w = (catt_total_unique_units / sum(catt_total_unique_units))
                ), by = list(ref_event_time)]
                # keep one row per event time
                ddt[, rowid := seq_len(.N), by = list(ref_event_time)]
                result <- ddt[rowid == 1 | is.na(rowid)]

                ew[[j]] <- result[, list(grouping, unweighted_estimate)]
                ew[[j]][, unweighted_estimate := mean(unweighted_estimate), by = list(grouping)]
                ew[[j]][, ref_onset_time := "Equally-Weighted + Collapsed"]
                ew[[j]][, rn := "att"]
                ew[[j]][, rowid := seq_len(.N)]
                setnames(ew[[j]], "unweighted_estimate", "estimate")
                ew[[j]] <- ew[[j]][rowid == 1 | is.na(rowid)]
                ew[[j]][, rowid := NULL]
                ew[[j]][, cluster_se := 0] # replaced by the delta-method SE below

                cw[[j]] <- result[, list(grouping, weighted_estimate_V1)]
                cw[[j]][, weighted_estimate_V1 := mean(weighted_estimate_V1), by = list(grouping)]
                cw[[j]][, ref_onset_time := "Cohort-Weighted + Collapsed"]
                cw[[j]][, rn := "att"]
                cw[[j]][, rowid := seq_len(.N)]
                setnames(cw[[j]], "weighted_estimate_V1", "estimate")
                cw[[j]] <- cw[[j]][rowid == 1 | is.na(rowid)]
                cw[[j]][, rowid := NULL]
                cw[[j]][, cluster_se := 0]

                cw2[[j]] <- result[, list(grouping, weighted_estimate_V2)]
                cw2[[j]][, weighted_estimate_V2 := mean(weighted_estimate_V2), by = list(grouping)]
                cw2[[j]][, ref_onset_time := "Cohort-Weighted V2 + Collapsed"]
                cw2[[j]][, rn := "att"]
                cw2[[j]][, rowid := seq_len(.N)]
                setnames(cw2[[j]], "weighted_estimate_V2", "estimate")
                cw2[[j]] <- cw2[[j]][rowid == 1 | is.na(rowid)]
                cw2[[j]][, rowid := NULL]
                cw2[[j]][, cluster_se := 0]
                rm(result)
                rm(ddt)
            }
            rm(g, j, group_event_times)
            ew <- rbindlist(ew, use.names = TRUE)
            cw <- rbindlist(cw, use.names = TRUE)
            cw2 <- rbindlist(cw2, use.names = TRUE)

            collapsed_means_of_ratio <- rbindlist(list(ew, cw, cw2), use.names = TRUE)
            rm(ew, cw, cw2)

            counts_and_weights <- figdata[
                rn == "att" & !is.na(grouping) & model == "rf",
                list(ref_onset_time, grouping, collapsed_estimate_total_unique_units)
            ]

            collapsed_means_of_ratio <- merge(collapsed_means_of_ratio, counts_and_weights,
                by = c("ref_onset_time", "grouping"), sort = FALSE
            )

            rm(counts_and_weights)

            # SEs
            for (g in unique(na.omit(collapse_input_dt[["name"]]))) {

                # extract event_times and results corresponding to grouping
                # as we won't have an estimate for the omitted_event_time, exclude it below
                group_event_times <- setdiff(
                    unique(na.omit(unlist(collapse_input_dt[name == g][[2]]))),
                    omitted_event_time
                )
                ddt <- catt_means_of_ratio[(ref_event_time %in% group_event_times)]
                ddt[, grouping := g]
                ddt <- ddt[, list(ref_event_time, ref_onset_time, catt_treated_unique_units, catt_total_unique_units, grouping)]

                templist <- list()
                i <- 0
                for (et in group_event_times) {
                    i <- i + 1

                    # The "$" anchor is crucial: otherwise, e.g., 1 would also find
                    # 10:19 (and -1 would find -19:-10).
                    if (et < 0) {
                        lookfor <- sprintf("cattlead%s$", abs(et))
                    } else {
                        lookfor <- sprintf("catt%s$", abs(et))
                    }
                    coef_indices <- grep(lookfor, names(catt_coefs_ratio))
                    rm(lookfor)
                    temp <- as.data.table(
                        do.call(cbind, list(catt_coefs_ratio[coef_indices], coef_indices)),
                        keep.rownames = TRUE
                    )
                    setnames(temp, c("V1", "V2"), c("estimate", "coef_index"))
                    rm(coef_indices)
                    temp[, estimate := NULL]
                    temp[, rn := gsub("lead", "-", rn)]
                    for (cohort in min_onset_time:max_onset_time) {
                        cohort_pattern <- sprintf("ref\\_onset\\_time%s", cohort)
                        temp[grepl(cohort_pattern, rn), ref_onset_time := cohort]
                        temp[
                            grepl(cohort_pattern, rn),
                            rn := gsub(sprintf("ref\\_onset\\_time%s\\_catt", cohort), "catt", rn)
                        ]
                    }
                    rm(cohort, cohort_pattern)
                    temp[grepl("catt", rn), ref_event_time := as.integer(gsub("catt", "", rn))]
                    temp[, rn := NULL]
                    temp[, ref_onset_time := as.character(ref_onset_time)]

                    # now merge in the within-event-time weights
                    temp <- merge(temp, ddt,
                        by = c("ref_onset_time", "ref_event_time"),
                        all.x = TRUE, sort = FALSE
                    )
                    temp <- temp[, list(ref_onset_time, ref_event_time, coef_index, catt_treated_unique_units, catt_total_unique_units)]
                    temp[, weight_V0 := 1 / .N]
                    temp[, cohort_weight_V1 := catt_treated_unique_units / sum(catt_treated_unique_units)]
                    temp[, cohort_weight_V2 := catt_total_unique_units / sum(catt_total_unique_units)]

                    templist[[i]] <- copy(temp)
                    rm(temp)
                    gc()
                }
                rm(i, et, group_event_times)

                templist <- rbindlist(templist, use.names = TRUE)

                # Now add the across-event-time weights and calculate full (multiplicative) weights
                templist[, across_weight := (1 / uniqueN(ref_event_time))]
                templist[, full_weight_V0 := weight_V0 * across_weight]
                templist[, full_weight_V1 := cohort_weight_V1 * across_weight]
                templist[, full_weight_V2 := cohort_weight_V2 * across_weight]

                templist[, equal_w_formula_entry := sprintf("(%s*x%s)", full_weight_V0, coef_index)]
                templist[, cohort_w_v1_formula_entry := sprintf("(%s*x%s)", full_weight_V1, coef_index)]
                templist[, cohort_w_v2_formula_entry := sprintf("(%s*x%s)", full_weight_V2, coef_index)]

                formula_input_ew <- paste0(templist$equal_w_formula_entry, collapse = "+")
                formula_input_cw <- paste0(templist$cohort_w_v1_formula_entry, collapse = "+")
                formula_input_cw2 <- paste0(templist$cohort_w_v2_formula_entry, collapse = "+")

                rm(templist)

                collapsed_means_of_ratio[
                    grouping == g & ref_onset_time == "Equally-Weighted + Collapsed",
                    cluster_se := delta_method(
                        g = as.formula(paste("~", formula_input_ew)),
                        mean = catt_coefs_ratio,
                        cov = catt_vcov_ratio,
                        ses = TRUE
                    )
                ]

                collapsed_means_of_ratio[
                    grouping == g & ref_onset_time == "Cohort-Weighted + Collapsed",
                    cluster_se := delta_method(
                        g = as.formula(paste("~", formula_input_cw)),
                        mean = catt_coefs_ratio,
                        cov = catt_vcov_ratio,
                        ses = TRUE
                    )
                ]

                collapsed_means_of_ratio[
                    grouping == g & ref_onset_time == "Cohort-Weighted V2 + Collapsed",
                    cluster_se := delta_method(
                        g = as.formula(paste("~", formula_input_cw2)),
                        mean = catt_coefs_ratio,
                        cov = catt_vcov_ratio,
                        ses = TRUE
                    )
                ]

                rm(formula_input_ew, formula_input_cw, formula_input_cw2)
                rm(ddt)
            }
            rm(g)
        }

        ratio_data <- rbindlist(list(catt_means_of_ratio, event_time_means_of_ratio, collapsed_means_of_ratio),
            use.names = TRUE,
            fill = TRUE
        )
        ratio_data[, model := "ratio"]
        rm(catt_means_of_ratio, event_time_means_of_ratio, collapsed_means_of_ratio)

        figdata <- rbindlist(list(figdata, ratio_data),
            use.names = TRUE,
            fill = TRUE
        )
        rm(ratio_data)
    }
    figdata[model == "rf", `:=`(model, "reduced_form")]
    figdata[model == "fs", `:=`(model, "first_stage")]
    means_rf[, `:=`(model, "reduced_form")]
    treat_means_rf[, `:=`(model, "reduced_form")]
    if (!(is.null(ref_event_time_mean_vars))) {
        req_varmeans1[, `:=`(model, "reduced_form")]
        req_treat_varmeans1[, `:=`(model, "reduced_form")]
        req_varmeans2[, `:=`(model, "reduced_form")]
        req_treat_varmeans2[, `:=`(model, "reduced_form")]
        req_varmeans3[, `:=`(model, "reduced_form")]
        req_treat_varmeans3[, `:=`(model, "reduced_form")]
    }
    if (!is.null(endog_var)) {
        means_fs[, `:=`(model, "first_stage")]
        treat_means_fs[, `:=`(model, "first_stage")]
    }
    if (bootstrapES == TRUE) {
        boot_results <- rbindlist(parallel::mclapply(
            X = 1:bootstrap_iters,
            FUN = Wald_bootstrap_ES, mc.silent = FALSE, mc.cores = bootstrap_num_cores,
            mc.set.seed = TRUE, long_data = original_sample,
            outcomevar = outcomevar, unit_var = unit_var, cal_time_var = cal_time_var,
            onset_time_var = onset_time_var, cluster_vars = cluster_vars,
            omitted_event_time = omitted_event_time, anticipation = anticipation,
            min_control_gap = min_control_gap, max_control_gap = max_control_gap,
            linearize_pretrends = linearize_pretrends, fill_zeros = fill_zeros,
            residualize_covariates = residualize_covariates,
            control_subset_var = control_subset_var, control_subset_event_time = control_subset_event_time,
            treated_subset_var = treated_subset_var, treated_subset_event_time = treated_subset_event_time,
            control_subset_var2 = control_subset_var2, control_subset_event_time2 = control_subset_event_time2,
            treated_subset_var2 = treated_subset_var2, treated_subset_event_time2 = treated_subset_event_time2,
            discrete_covars = discrete_covars, cont_covars = cont_covars,
            never_treat_action = never_treat_action, homogeneous_ATT = homogeneous_ATT,
            reg_weights = reg_weights, add_unit_fes = add_unit_fes,
            ipw = ipw, ipw_model = ipw_model, ipw_composition_change = ipw_composition_change,
            ipw_keep_data = FALSE, ipw_ps_lower_bound = ipw_ps_lower_bound,
            ipw_ps_upper_bound = ipw_ps_upper_bound, event_vs_noevent = event_vs_noevent,
            ref_discrete_covars = ref_discrete_covars, ref_discrete_covar_event_time = ref_discrete_covar_event_time,
            ref_cont_covars = ref_cont_covars, ref_cont_covar_event_time = ref_cont_covar_event_time,
            calculate_collapse_estimates = calculate_collapse_estimates,
            collapse_inputs = collapse_inputs, ref_reg_weights = ref_reg_weights,
            ref_reg_weights_event_time = ref_reg_weights_event_time,
            ntile_var = ntile_var, ntile_event_time = ntile_event_time,
            ntiles = ntiles, ntile_var_value = ntile_var_value,
            ntile_avg = ntile_avg, endog_var = endog_var, cohort_by_cohort = cohort_by_cohort,
            cohort_by_cohort_num_cores = cohort_by_cohort_num_cores,
            heterogeneous_only = heterogeneous_only
        ), use.names = TRUE)
        if (keep_all_bootstrap_results == TRUE) {
            bootstrap_output <- copy(boot_results)
        }
        if (calculate_collapse_estimates == TRUE & homogeneous_ATT ==
            FALSE) {
            boot_results_collapse_estimates <- boot_results[!is.na(grouping)]
            boot_results <- boot_results[is.na(grouping)]
            boot_results[, `:=`(grouping, NULL)]
            gc()
        }
        boot_results_ses <- boot_results[, list(bootstrap_se = sd(estimate)),
            by = list(ref_onset_time, ref_event_time, model)
        ]
        order_to_restore <- na.omit(unique(c(
            copy(colnames(figdata)),
            copy(colnames(boot_results_ses))
        )))
        figdata <- merge(figdata, boot_results_ses, by = c(
            "ref_onset_time",
            "ref_event_time", "model"
        ), all.x = TRUE, sort = FALSE)
        setcolorder(figdata, neworder = order_to_restore)
        rm(order_to_restore)
        figdata[rn == "treatment_means", `:=`(
            bootstrap_se,
            NA
        )]
        if (calculate_collapse_estimates == TRUE & homogeneous_ATT ==
            FALSE) {
            boot_results_collapse_estimates_ses <- boot_results_collapse_estimates[,
                list(bootstrap_se_collapse = sd(estimate)),
                by = list(ref_onset_time, grouping, model)
            ]
            order_to_restore <- na.omit(unique(c(
                copy(colnames(figdata)),
                copy(colnames(boot_results_collapse_estimates_ses))
            )))
            figdata <- merge(figdata, boot_results_collapse_estimates_ses,
                by = c("ref_onset_time", "grouping", "model"),
                all.x = TRUE, sort = FALSE
            )
            setcolorder(figdata, neworder = order_to_restore)
            rm(order_to_restore)
            figdata[!is.na(bootstrap_se_collapse), `:=`(
                bootstrap_se,
                bootstrap_se_collapse
            )]
            figdata[, `:=`(bootstrap_se_collapse, NULL)]
        }
        original_sample <- NULL
        gc()
    }
    if (ipw_keep_data == TRUE) {
        figdata <- rbindlist(list(figdata, ipw_dt),
            use.names = TRUE,
            fill = TRUE
        )
    }
    return_list <- list()
    return_list[[1]] <- figdata
    return_list[[2]] <- list(means_rf, means_fs)
    return_list[[3]] <- list(treat_means_rf, treat_means_fs)
    if (homogeneous_ATT == FALSE) {
        return_list[[4]] <- list(catt_coefs_rf, catt_vcov_rf)
        if (!is.null(endog_var)) {
            return_list[[5]] <- list(catt_coefs_fs, catt_vcov_fs)
        }
    }
    if (ipw_keep_data == TRUE) {
        return_list[[6]] <- ipw_dt
    }
    if (bootstrapES == TRUE & keep_all_bootstrap_results ==
        TRUE) {
        return_list[[7]] <- bootstrap_output
    }
    if (!(is.null(ref_event_time_mean_vars))) {
        return_list[[8]] <- list(req_varmeans1, req_treat_varmeans1, req_varmeans2, req_treat_varmeans2, req_varmeans3, req_treat_varmeans3)
    }

    # Now grab lower thresholds (if ntile_var supplied)
    if (!(is.null(ntile_var))) {
        return_list[[9]] <- copy(bar_y)
    }

    flog.info("Wald_ES is finished.")
    return(return_list)
}
